rislab/pytorchviz

PyTorchViz

A small package to create visualizations of PyTorch execution graphs and traces.

Installation

Install Graphviz, e.g. with Homebrew on macOS:

brew install graphviz

Install the package itself from PyPI:

pip install torchviz

Or install this fork:

pip install git+https://github.com/rislab/pytorchviz.git

Or, if you plan on making changes locally, clone it and install it in editable mode:

git clone git@github.com:rislab/pytorchviz.git
cd pytorchviz
pip install -e .
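torchviz only builds the graph description; rendering it to an image needs the Graphviz `dot` executable on your PATH. A quick way to confirm that the install worked is a small stdlib-only check like the one below. The helper name `graphviz_version` is made up for this example and is not part of torchviz.

```python
import shutil
import subprocess


def graphviz_version():
    """Return the `dot -V` banner, or None if Graphviz's `dot` is not on PATH."""
    dot = shutil.which("dot")
    if dot is None:
        return None
    # `dot -V` prints its version banner to stderr rather than stdout.
    result = subprocess.run([dot, "-V"], capture_output=True, text=True)
    return result.stderr.strip() or result.stdout.strip()


if __name__ == "__main__":
    version = graphviz_version()
    print(version if version else "Graphviz 'dot' not found on PATH")
```

If this prints the "not found" message, rendering calls such as `view()` or `render()` will fail even though `make_dot` itself can still build the graph.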

Usage

Example usage of make_dot:

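import torch
from torch import nn
from torchviz import make_dot

model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))

x = torch.randn(1, 8)
y = model(x)

# Returns a graphviz.Digraph; in a Jupyter notebook it renders inline
make_dot(y.mean(), params=dict(model.named_parameters()))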
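make_dot works by starting at the output's grad_fn and walking backward through each node's next_functions, emitting one DOT node per autograd function and an edge from each input toward the node that consumes it. The sketch below models that walk with plain Python objects so it runs without PyTorch. `Node` and `to_dot` are hypothetical names, and this is a simplified illustration rather than torchviz's actual implementation, which also draws boxes for parameters and other leaf tensors.

```python
from collections import deque


class Node:
    """Hypothetical stand-in for an autograd function such as MeanBackward0."""

    def __init__(self, name, inputs=()):
        self.name = name
        # Mirrors grad_fn.next_functions: a tuple of (node, output_index) pairs.
        # A node may be None for inputs that do not require grad.
        self.next_functions = tuple((node, 0) for node in inputs)


def to_dot(root):
    """Breadth-first walk backward from `root`, returning Graphviz DOT source."""
    lines = ["digraph {"]
    seen = {id(root)}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        lines.append(f'  "{id(node)}" [label="{node.name}"]')
        for parent, _ in node.next_functions:
            if parent is None:
                continue  # nothing to draw for inputs without a grad_fn
            # Edge points from the input toward the node that consumes it.
            lines.append(f'  "{id(parent)}" -> "{id(node)}"')
            if id(parent) not in seen:
                seen.add(id(parent))
                queue.append(parent)
    lines.append("}")
    return "\n".join(lines)


if __name__ == "__main__":
    # A simplified chain resembling y.mean() after a Linear layer and tanh.
    weight = Node("AccumulateGrad")
    addmm = Node("AddmmBackward0", [None, weight])
    tanh = Node("TanhBackward0", [addmm])
    mean = Node("MeanBackward0", [tanh])
    print(to_dot(mean))
```

The `seen` set keeps a node that feeds several consumers from being expanded twice, while every consumer still gets its own incoming edge.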

Labeling a tensor

Example of labeling a tensor with label_var(tensor, "label").

See test/test_labels.py::test_label_var.

import torch
import torchviz as tv

A = torch.randn(3, 3, requires_grad=True)
B = torch.randn(3, 3, requires_grad=True)
C = A @ B
# label a tensor
C = tv.label_var(C, "C")
loss = C.sum()

dot = tv.make_dot(loss, params={"A": A, "B": B})
dot.view()
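A label is extra text attached to one node so that it is easy to find in a large graph. This README does not show how the fork stores or renders labels, so the sketch below only illustrates the DOT side under stated assumptions: `dot_node` and the `labels` dict (node id to label) are hypothetical, and the label is placed on its own line using Graphviz's `\n` escape inside a quoted label.

```python
def dot_node(node_id, op_name, labels):
    """Format one DOT node statement, prefixing the user's label when one exists.

    `labels` is a hypothetical dict mapping node ids to user-chosen names.
    """
    if node_id in labels:
        # "\\n" in Python source yields the two characters \n, which Graphviz
        # interprets as a line break inside a quoted label.
        text = f"{labels[node_id]}\\n{op_name}"
        return f'  "{node_id}" [label="{text}" style=filled fillcolor=lightblue]'
    return f'  "{node_id}" [label="{op_name}"]'


if __name__ == "__main__":
    labels = {"n2": "C"}  # label the matmul output, as in the example above
    print("digraph {")
    print(dot_node("n1", "SumBackward0", labels))
    print(dot_node("n2", "MmBackward0", labels))
    print('  "n2" -> "n1"')
    print("}")
```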

Labeling an arg

Example of labeling function arguments with label_arg(tensor, "label").

See test/test_labels.py::test_label_args.

import torch
import torchviz as tv

A = torch.randn(3, 3, requires_grad=True)
B = torch.randn(3, 3, requires_grad=True)
# label the args
A = tv.label_arg(A, "A")
B = tv.label_arg(B, "B")
C = A @ B
# label a tensor
C = tv.label_var(C, "C")
loss = C.sum()

dot = tv.make_dot(loss, params={"A": A, "B": B})
dot.view()

Labeling a return value

Example of labeling a return value with label_ret(tensor, "label"). When both the args and the return values of a function are labelled, a subgraph for that function is created automatically.

See test/test_labels.py::test_label_args_rets.

import torch
import torchviz as tv

A = torch.randn(3, 3, requires_grad=True)
B = torch.randn(3, 3, requires_grad=True)
# label the args
A = tv.label_arg(A, "A")
B = tv.label_arg(B, "B")
C = A @ B
# label return value
C = tv.label_ret(C, "C")
loss = C.sum()

dot = tv.make_dot(loss, params={"A": A, "B": B})
dot.view()
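Graphviz draws a subgraph whose name starts with `cluster` as a boxed group, which is a natural way to show one function call as a unit. The sketch below writes such a cluster by hand for the `A @ B` example. `cluster` and `graph_with_cluster` are hypothetical helpers, the node names are placeholders, and the exact grouping the fork produces may differ.

```python
def cluster(name, node_ids, title):
    """Return DOT source for a cluster subgraph that groups `node_ids`."""
    # Graphviz treats subgraphs whose names start with "cluster" as boxed groups.
    lines = [f"  subgraph cluster_{name} {{", f'    label="{title}"']
    lines += [f'    "{node_id}"' for node_id in node_ids]
    lines.append("  }")
    return "\n".join(lines)


def graph_with_cluster():
    """DOT source for loss = (A @ B).sum() with the matmul call grouped."""
    edges = [("A", "mm"), ("B", "mm"), ("mm", "C"), ("C", "sum")]
    body = [f'  "{src}" -> "{dst}"' for src, dst in edges]
    # One possible grouping: labeled args A and B, the op, and labeled return C.
    body.append(cluster("matmul_0", ["A", "B", "mm", "C"], "A @ B"))
    return "digraph {\n" + "\n".join(body) + "\n}"


if __name__ == "__main__":
    print(graph_with_cluster())
```

Piping the printed text into `dot -Tpng -o out.png` should show the four grouped nodes inside a box titled "A @ B", with `sum` outside it.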

Labeling a function

Example of labeling a function. When a function takes tensor args and kwargs and returns a tensor or a tuple of tensors, the label_fn decorator can do the labeling for you. It calls label_arg and label_ret under the hood and tracks which function each tensor belongs to (as well as how many times the function has been called), so separate calls to the same function get separate groups. All you have to do is pass labels for the return values: label_fn("ret 1", "ret 2", ...).

See test/test_labels.py::test_label_fn.

import torch
import torchviz as tv

@tv.label_fn("A", "C")
def foo(A, B):
    C = A @ B
    return C, A

@tv.label_fn()
def bar(X, Y):
    Z, X = foo(X, Y)
    X, _ = foo(X, Y)
    W = X * Z
    return W

A = torch.randn(3, 3, requires_grad=True)
B = torch.randn(3, 3, requires_grad=True)
C = bar(A, B)
loss = C.sum()

dot = tv.make_dot(loss, params={"A": A, "B": B})
dot.render('test/test_label_fn', format='png', view=True, cleanup=True)
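The call-tracking idea behind label_fn can be sketched with an ordinary decorator: keep a per-function call counter and record each labeled return value under a group key such as `foo#0` or `foo#1`. `label_fn_sketch`, `call_counts`, and `groups` are hypothetical names used only for illustration. The real label_fn labels tensors in the autograd graph, whereas this toy runs on plain numbers and just records (group, label, value) tuples.

```python
import functools
from collections import Counter

call_counts = Counter()  # how many times each decorated function has run
groups = []  # (group key, label, value) records, one per labeled return value


def label_fn_sketch(*ret_labels):
    """Hypothetical stand-in for tv.label_fn that records which call produced each output."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            name = fn.__name__
            group = f"{name}#{call_counts[name]}"  # e.g. "foo#0", "foo#1"
            call_counts[name] += 1
            out = fn(*args, **kwargs)
            outs = out if isinstance(out, tuple) else (out,)
            # Unlabeled return values (beyond len(ret_labels)) are not recorded.
            for label, value in zip(ret_labels, outs):
                groups.append((group, label, value))
            return out  # return values pass through unchanged

        return wrapper

    return decorator


@label_fn_sketch("prod", "left")
def foo(a, b):
    return a * b, a


@label_fn_sketch()
def bar(x, y):
    z, x = foo(x, y)
    x, _ = foo(x, y)
    return x * z


result = bar(2, 3)
```

Calling `bar(2, 3)` runs `foo` twice, so its outputs land in two separate groups, `foo#0` and `foo#1`, which is the per-call grouping described above.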

Acknowledgements

The script was moved from functional-zoo, where it was created with the help of Adam Paszke, Soumith Chintala, and Anton Osokin, and it uses bits from tensorboard-pytorch. Other contributors are @willprice, @soulitzer, and @albanD.
