
Error Code 1: Cask (isConsistent) #121

@MHGL

Description


question

I get this error while converting a module to TensorRT. The module:

  • has 5 downsample (stride-2 convolution) layers
  • upsamples after the last downsample
  • joins the results with torch.cat

To Reproduce

Steps to reproduce the behavior:

  1. Code example
import torch
import torch.nn.functional as F

class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        for i in range(1, 6):
            setattr(self, f"down{i}", torch.nn.Conv2d(3, 3, 3, 2, padding=1))

    def forward(self, x):
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)
        return torch.cat([x4, F.interpolate(x5, scale_factor=2)], 1)

torch_model = MyModule()

# torch.onnx.export
torch.onnx.export(torch_model,
    torch.randn(1, 3, 224, 224),
    "./tmp.onnx",
    input_names=["inputs"],
    output_names=["outputs"],
    dynamic_axes={"inputs": {0: "batch", 2: "height", 3: "width"}, "outputs": {0: "batch", 1: "class", 2: "height", 3: "width"}},
    opset_version=11,
    export_params=True)

import os
onnx_file = os.path.join(os.getcwd(), "tmp.onnx")

# onnx -> tensorrt
# NOTE: TensorRT (and its Python bindings) must be installed first
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
    with open(onnx_file, 'rb') as model:
        parser.parse(model.read())

    config = builder.create_builder_config()

    profile = builder.create_optimization_profile()
    profile.set_shape("inputs", (1, 3, 1, 1), (1, 3, 224, 224), (1, 3, 2000, 2000))
    config.add_optimization_profile(profile)

    engine = builder.build_engine(network, config)
    with open("tmp.trt", "wb") as f:
        f.write(engine.serialize())
  2. Stack traces
  • Sometimes the build fails and I get this (a more defensive version of the build step is sketched after these traces):
mini_code.py:54: DeprecationWarning: Use build_serialized_network instead.
  engine = builder.build_engine(network, config)
[TensorRT] WARNING: Convolution + generic activation fusion is disable due to incompatible driver or nvrtc
[TensorRT] WARNING: TensorRT was linked against cuBLAS/cuBLAS LT 11.4.2 but loaded cuBLAS/cuBLAS LT 11.2.1
[TensorRT] WARNING: Detected invalid timing cache, setup a local cache instead
[TensorRT] ERROR: 1: [convolutionBuilder.cpp::createConvolution::184] Error Code 1: Cask (isConsistent)
Traceback (most recent call last):
  File "mini_code.py", line 56, in <module>
    f.write(engine.serialize())
AttributeError: 'NoneType' object has no attribute 'serialize'
  • Sometimes the build succeeds and I get this:
mini_code.py:54: DeprecationWarning: Use build_serialized_network instead.
  engine = builder.build_engine(network, config)
[TensorRT] WARNING: Convolution + generic activation fusion is disable due to incompatible driver or nvrtc
[TensorRT] WARNING: TensorRT was linked against cuBLAS/cuBLAS LT 11.4.2 but loaded cuBLAS/cuBLAS LT 11.2.1
[TensorRT] WARNING: Detected invalid timing cache, setup a local cache instead
[TensorRT] WARNING: Max value of this profile is not valid
[TensorRT] WARNING: Min value of this profile is not valid
[TensorRT] WARNING: TensorRT was linked against cuBLAS/cuBLAS LT 11.4.2 but loaded cuBLAS/cuBLAS LT 11.2.1

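The AttributeError in the failing run is a follow-on error: build_engine returns None when the build fails, and the script then calls serialize() on None. Below is a minimal, more defensive sketch of the same build step, assuming the same tmp.onnx file; it only uses documented calls (the boolean return of parser.parse, parser.num_errors / parser.get_error, and build_serialized_network, which the deprecation warning itself suggests).

import os
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
onnx_file = os.path.join(os.getcwd(), "tmp.onnx")

with trt.Builder(TRT_LOGGER) as builder, \
     builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \
     trt.OnnxParser(network, TRT_LOGGER) as parser:

    # Report parser errors instead of silently continuing with an incomplete network.
    with open(onnx_file, "rb") as model:
        if not parser.parse(model.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("ONNX parse failed")

    config = builder.create_builder_config()
    profile = builder.create_optimization_profile()
    # Shapes copied from the repro; a very small min shape may be invalid for a
    # network that downsamples 5 times (see the "profile is not valid" warnings).
    profile.set_shape("inputs", (1, 3, 1, 1), (1, 3, 224, 224), (1, 3, 2000, 2000))
    config.add_optimization_profile(profile)

    # build_serialized_network is the non-deprecated path; it returns None on failure,
    # so check before writing the plan file instead of crashing on serialize().
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        raise RuntimeError("Engine build failed")

    with open("tmp.trt", "wb") as f:
        f.write(serialized_engine)

This does not avoid the Cask (isConsistent) error itself, but it makes the failure explicit rather than ending in the AttributeError above.
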
Expected behavior

Environment

  • TensorRT Version: 8.0.0.3
  • PyTorch Version: 1.9.0
  • OS (e.g., MacOS, Linux): Ubuntu 20.04 LTS
  • How you installed Python (anaconda, virtualenv, system): miniconda
  • Python version (e.g. 3.7): 3.8.5
  • Any other relevant information:
    • GPU: GeForce GTX 1650
    • Driver Version: 460.80
    • CUDA Version: 11.2
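For reference, the versions above can be double-checked from the same Python environment with a short snippet (standard module attributes only; the expected values in the comments are the ones listed here):

import torch
import tensorrt as trt

print("TensorRT:", trt.__version__)           # 8.0.0.3
print("PyTorch:", torch.__version__)          # 1.9.0
print("CUDA (torch build):", torch.version.cuda)
print("GPU:", torch.cuda.get_device_name(0))  # GeForce GTX 1650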
