
Torch 1.13 - 2.3 Onnx Scope name not correct! #90439

Closed
kevalmorabia97 opened this issue Dec 8, 2022 · 9 comments
Labels
low priority We're unlikely to get around to doing this in the near future module: onnx Related to torch.onnx triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@kevalmorabia97
Contributor

kevalmorabia97 commented Dec 8, 2022

🐛 Describe the bug

In torch 1.13.0, ONNX scope names were introduced, replacing the old ONNX node naming format (Conv_0, Relu_1, ...) with a new format that captures the PyTorch submodule name. But the actual ONNX node name may not always correspond to a submodule, and instead needs some further post-processing.

Example:

import io
import onnx
import torch
from torchvision.models import resnet18

# Export ResNet18 to an in-memory ONNX model and print every node name.
buffer = io.BytesIO()
torch.onnx.export(resnet18(), torch.randn(1, 3, 224, 224), buffer)
buffer.seek(0, 0)
onnx_model = onnx.load(buffer)
for node in onnx_model.graph.node:
    print(node.name)

Output:

...
/conv1/Conv
/relu/Relu
/maxpool/MaxPool
/layer1/layer1.0/conv1/Conv
/layer1/layer1.0/relu/Relu
/layer1/layer1.0/conv2/Conv
...
/layer4/layer4.1/relu_1/Relu
/avgpool/GlobalAveragePool
/Flatten
/fc/Gemm

As we can see here for ResNet18, there is no Conv submodule named layer1.layer1.0.conv1 in the model, but rather layer1.0.conv1. So its ONNX node name should be /layer1/0/conv1/Conv.

Sidenote: ideally, it would be even more helpful to have the submodule name itself as the ONNX node name, i.e. layer1.0.conv1 instead of /layer1/layer1.0/conv1/Conv.
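
For anyone who needs the real submodule names today, some post-processing along these lines works. This is a minimal sketch (scope_to_module_name is a hypothetical helper), assuming node names follow the /scope/.../OpType pattern shown above, with numeric segments repeating their parent prefix and modules called more than once getting a _N suffix such as relu_1:

def scope_to_module_name(node_name: str) -> str:
    # "/layer1/layer1.0/conv1/Conv" -> "layer1.0.conv1"
    segments = node_name.strip("/").split("/")[:-1]  # drop the trailing op type
    name = ""
    for seg in segments:
        if not name:
            name = seg
        elif seg.startswith(name + "."):
            # Segment already repeats the accumulated prefix (the bug above).
            name = seg
        else:
            name = f"{name}.{seg}"
    return name

for node in onnx_model.graph.node:
    print(scope_to_module_name(node.name))  # e.g. "layer1.0.conv1"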

Versions

Collecting environment information...
PyTorch version: 1.13.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.0.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: version 3.22.1
Libc version: N/A

Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:14)  [Clang 12.0.1 ] (64-bit runtime)
Python platform: macOS-13.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy==0.971
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.1
[pip3] pytorch-sphinx-theme==0.0.19
[pip3] torch==1.13.0
[pip3] torchaudio==0.13.0.dev20221003
[pip3] torchpack==0.3.1
[pip3] torchprofile==0.0.4
[pip3] torchvision==0.14.0
[conda] cpuonly                   2.0                           0    pytorch-nightly
[conda] numpy                     1.23.1           py38h42add53_0  
[conda] numpy-base                1.23.1           py38hadd41eb_0  
[conda] pytorch-mutex             1.0                         cpu    pytorch-nightly
[conda] pytorch-sphinx-theme      0.0.19                   pypi_0    pypi
[conda] torch                     1.13.0                   pypi_0    pypi
[conda] torchaudio                0.13.0.dev20221003        py38_cpu    pytorch-nightly
[conda] torchpack                 0.3.1                    pypi_0    pypi
[conda] torchprofile              0.0.4                    pypi_0    pypi
[conda] torchvision               0.14.0                   pypi_0    pypi
@soulitzer soulitzer added module: onnx Related to torch.onnx triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Dec 9, 2022
@BowenBao
Collaborator

BowenBao commented Jan 10, 2023

Thanks for reporting. It looks like the discrepancy is due to a fix for Sequential modules being skipped by torch in scope recording; we will need to investigate.

import io

import onnx
import torch

# Run this without the numeric-atom patch in `_unqualified_variable_name`
class MainModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.module_list = torch.nn.ModuleList(
            [torch.nn.Linear(10, 10) for _ in range(2)]
        )

    def forward(self, x):
        y = self.module_list[0](x)
        z = self.module_list[1](y)
        return z

module = MainModule()
print(module)
f = io.BytesIO()
torch.onnx.export(module, torch.randn(1, 10), f, verbose=True)
f.seek(0, 0)
onnx_model = onnx.load(f)
for node in onnx_model.graph.node:
    print(node.name)

/0/Gemm
/1/Gemm

where it should be

/module_list/0/Gemm
/module_list/1/Gemm

@kevalmorabia97
Contributor Author

@BowenBao thanks for looking into this. I am getting the following output with the model you suggested:

/module_list.0/Gemm
/module_list.1/Gemm
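
For context, the difference from the /0/Gemm output above is presumably the shipped numeric-atom patch in _unqualified_variable_name: a purely numeric scope atom is folded into the preceding non-numeric atom instead of being emitted on its own. A rough sketch of that behavior as I understand it from this thread (not the exact torch source):

def unqualified_variable_name(qualified_name: str) -> str:
    # Keep trailing numeric atoms attached to the last non-numeric atom,
    # so a Linear reached as "module_list.0" is not reduced to a bare "0".
    atoms = qualified_name.split(".")
    for i in range(len(atoms) - 1, -1, -1):
        if not atoms[i].isnumeric():
            return ".".join(atoms[i:])
    return qualified_name

print(unqualified_variable_name("__module.module_list.0"))   # module_list.0
print(unqualified_variable_name("__module.layer1.0.conv1"))  # conv1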

@kevalmorabia97
Contributor Author

kevalmorabia97 commented Jan 12, 2023

I tested the intel-isl/MiDaS model from torch.hub, where I can also see incorrect ONNX node names.

module = torch.hub.load("intel-isl/MiDaS:f28885af", "MiDaS_small")
input = torch.randn(1, 3, 320, 640)
torch.onnx.export(module, input, "midas_small.onnx")  # output filename is arbitrary

All ONNX node names are missing the /pretrained prefix where it should be present. For example, I am seeing /layer3/layer3.1/layer3.1.0/conv_pw/Conv instead of /pretrained/layer3/layer3.1/layer3.1.0/conv_pw/Conv. Here model.pretrained.layer3 is a Sequential module.

@BowenBao
Collaborator

@kevalmorabia97 thanks for providing another great example. I took another look, and it turns out PyTorch populates the scope name in nn.Module's _slow_forward method, which only runs when a module's forward is actually invoked. That's why my simple example above doesn't capture the container module name in the scope: indexing into the ModuleList calls the children directly, so the ModuleList's own forward is never invoked.

This is also the reason why pretrained is not appearing for intel-isl/MiDaS, if I'm looking at the correct code
https://github.com/isl-org/MiDaS/blob/master/midas/midas_net.py#L59-L62.

A more precise solution is probably to store the fully qualified name of each module as its scope name, instead of relying on the pytorch scope tree structure. However, that may have implications for other components.
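
To make the mechanism concrete, here is a minimal sketch (my own illustration, assuming torch 1.13-style scope naming) of a parent module whose forward is never invoked because the caller reaches through it to a grandchild, which is exactly the MiDaS pattern:

import io

import onnx
import torch

class Backbone(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pretrained = Backbone()

    def forward(self, x):
        # Reaches through `pretrained` to the grandchild, so Backbone.forward
        # (and hence its scope recording) never runs.
        return self.pretrained.layer1(x)

f = io.BytesIO()
torch.onnx.export(Net(), torch.randn(1, 10), f)
f.seek(0, 0)
for node in onnx.load(f).graph.node:
    print(node.name)  # I'd expect "/layer1/Gemm", with the /pretrained prefix missing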

@kevalmorabia97
Contributor Author

Thanks @BowenBao for the explanation.

Do you have a rough estimate of how much effort the fix requires, and whether it could make it into the next release? I find the pytorch scope names super useful, since they map ONNX nodes back to the PyTorch modules and so provide insights at the PyTorch module level.

@BowenBao
Collaborator

I will bring this to the attention of our team; we need to figure out if there is enough bandwidth.
Optimistically, the improvement itself shouldn't require too much work; we just need to make sure it doesn't break anything else.

@kevalmorabia97
Contributor Author

@BowenBao any updates on this?

@justinchuby justinchuby added the low priority We're unlikely to get around to doing this in the near future label Jan 23, 2024
@thiagocrepaldi
Collaborator

> @BowenBao any updates on this?

Could you retest this, please?

@kevalmorabia97
Contributor Author

Still seeing the same issue with torch 2.3.0 on the intel-isl/MiDaS model.

@kevalmorabia97 kevalmorabia97 changed the title Torch 1.13 Onnx Scope name not correct! Torch 1.13 - 2.3 Onnx Scope name not correct! May 4, 2024
@BowenBao BowenBao removed their assignment May 16, 2024
@justinchuby justinchuby closed this as not planned Oct 30, 2024