Build failure of TensorRT 10.16.1.11 when running ONNX Attention inside If on GPU #4757

@ALinrunrun

Description

TensorRT fails to build an ONNX model when an Attention node is placed inside an If subgraph.

The same Attention node builds successfully when it is used at the top level of the graph. However, when the node is wrapped in both branches of an ONNX If, the TensorRT engine build returns None and logs an internal Myelin error.

This appears to be a TensorRT build/optimizer issue for ONNX Attention inside control-flow subgraphs.

Environment

TensorRT Version: 10.16.1.11

NVIDIA GPU: N/A / not detected by nvidia-smi

NVIDIA Driver Version: N/A / nvidia-smi failed

CUDA Version: N/A / nvcc not found

CUDNN Version: N/A / torch.backends.cudnn.version() returned None

Operating System: Linux 6.17.0-20-generic x86_64, glibc 2.39

Python Version (if applicable): Python 3.11.15

Tensorflow Version (if applicable): N/A

PyTorch Version (if applicable): 2.11.0+cpu

Baremetal or Container (if so, version): Baremetal / non-Docker environment (/proc/1/cgroup: 0::/init.scope)

Additional package versions:

ONNX Version: 1.21.0
ONNX Runtime Version: 1.25.1

Relevant Files

Model link: N/A

The ONNX models are generated inline by the minimal reproducible script below.

Steps To Reproduce

import numpy as np
import onnx
from onnx import helper, TensorProto
import tensorrt as trt

# B = batch, S = seq_len, N = num_heads, H = head_size; Q/K/V use the 4D [B, N, S, H] layout.
B, S, N, H = 1, 64, 8, 64
dtype = TensorProto.FLOAT16


def make_attn_only():
    # Baseline: a single Attention node at the top level of the graph.
    Q = helper.make_tensor_value_info("Q", dtype, [B, N, S, H])
    K = helper.make_tensor_value_info("K", dtype, [B, N, S, H])
    V = helper.make_tensor_value_info("V", dtype, [B, N, S, H])
    Y = helper.make_tensor_value_info("Y", dtype, [B, N, S, H])

    n = helper.make_node("Attention", ["Q", "K", "V"], ["Y"])
    g = helper.make_graph([n], "g", [Q, K, V], [Y])
    m = helper.make_model(g, opset_imports=[helper.make_opsetid("", 23)])
    m.ir_version = 10
    return m.SerializeToString()


def make_attn_in_if():
    # Same Attention node, placed inside both branches of an ONNX If.
    cond = helper.make_tensor_value_info("cond", TensorProto.BOOL, [])
    Q = helper.make_tensor_value_info("Q", dtype, [B, N, S, H])
    K = helper.make_tensor_value_info("K", dtype, [B, N, S, H])
    V = helper.make_tensor_value_info("V", dtype, [B, N, S, H])
    Y = helper.make_tensor_value_info("Y", dtype, [B, N, S, H])

    def branch(name):
        # Branch subgraphs declare no explicit inputs; Q, K, V are captured
        # implicitly from the outer graph's scope, which ONNX permits.
        out = helper.make_tensor_value_info("out", dtype, [B, N, S, H])
        n = helper.make_node("Attention", ["Q", "K", "V"], ["out"])
        return helper.make_graph([n], name, [], [out])

    if_n = helper.make_node(
        "If",
        ["cond"],
        ["Y"],
        then_branch=branch("then"),
        else_branch=branch("else"),
    )

    g = helper.make_graph([if_n], "outer", [cond, Q, K, V], [Y])
    m = helper.make_model(g, opset_imports=[helper.make_opsetid("", 23)])
    m.ir_version = 10
    return m.SerializeToString()


def try_build(ob):
    # Parse the serialized ONNX bytes into a strongly typed network and
    # attempt to build a serialized engine.
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)
    flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
    net = builder.create_network(flags)

    p = trt.OnnxParser(net, logger)
    if not p.parse(ob):
        return False, "parse failed"

    cfg = builder.create_builder_config()
    cfg.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)

    plan = builder.build_serialized_network(net, cfg)
    return plan is not None, "ok" if plan is not None else "build returned None"


ok_top, msg_top = try_build(make_attn_only())
ok_if, msg_if = try_build(make_attn_in_if())

print("TRT top-level Attention build:", ok_top, msg_top)
print("TRT Attention-inside-If build:", ok_if, msg_if)

assert ok_top is True
assert ok_if is False

Have you tried the latest release?: Yes, reproduced with TensorRT 10.16.1.11.

Attach the captured .json and .bin files from TensorRT's API Capture tool if you're on an x86_64 Unix system: Not attached. The issue is reproducible from the self-contained Python script above.

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): The top-level Attention model builds successfully in TensorRT. The failure only appears when the same Attention node is placed inside an ONNX If subgraph.
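
As a cross-check outside TensorRT, the If-wrapped model can also be exercised under ONNX Runtime. A minimal sketch, reusing make_attn_in_if() from the script above and assuming the installed onnxruntime supports opset-23 Attention in fp16 on the CPU execution provider:

import numpy as np
import onnxruntime as ort

# Sanity check (assumption, not part of the original report): run the
# If-wrapped model under ONNX Runtime with random fp16 inputs.
sess = ort.InferenceSession(make_attn_in_if(), providers=["CPUExecutionProvider"])
feeds = {
    "cond": np.array(True),
    "Q": np.random.randn(B, N, S, H).astype(np.float16),
    "K": np.random.randn(B, N, S, H).astype(np.float16),
    "V": np.random.randn(B, N, S, H).astype(np.float16),
}
print(sess.run(None, feeds)[0].shape)  # expected: (1, 8, 64, 64)

If this runs cleanly, it would further point at the TensorRT build path rather than the model itself; equivalently, the serialized bytes can be written to a .onnx file and checked with polygraphy run model.onnx --onnxrt.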

Actual output:

[05/11/2026-05:09:16] [TRT] [E] Error Code: 9: Skipping tactic 0x0000000000000000 due to exception [ir_op_builder.cpp:249: myelinOpSetInput] Called with unknown input tensor or sequence name "(Unnamed Layer* 8) [ElementWise]_output". In createMyelinOp at /_src/optimizer/myelin/codeGenerator.h:1479
[05/11/2026-05:09:16] [TRT] [E] IBuilder::buildSerializedNetwork: Error Code 10: Internal Error (Could not find any implementation for node {ForeignNode[ONNXTRT_ShapeTensorFromDims...node_of_Y_OutputLayer]}. In computeCosts at /_src/optimizer/common/tactic/optimizer.cpp:4265)
TRT top-level Attention build: True ok
TRT Attention-inside-If build: False build returned None

The top-level Attention case succeeds, while the equivalent Attention inside If fails during engine build with an internal TensorRT/Myelin error.


Labels

Module:Engine Build (Issues with building TensorRT engines), Module:ONNX (Issues relating to ONNX usage and import)
