Skip to content

Commit 8e75039

Browse files
authored
bugfix: allow empty tuple for inputs or arg_inputs (#3122)
1 parent ae7e6c8 commit 8e75039

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

py/torch_tensorrt/dynamo/_compiler.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -604,10 +604,10 @@ def convert_exported_program_to_serialized_trt_engine(
604604
DeprecationWarning,
605605
stacklevel=2,
606606
)
607-
if not arg_inputs and not inputs:
607+
if arg_inputs is None and inputs is None:
608608
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
609609

610-
elif arg_inputs and inputs:
610+
elif arg_inputs is not None and inputs is not None:
611611
raise AssertionError(
612612
"'arg_inputs' and 'inputs' should not be used at the same time."
613613
)

tests/py/dynamo/models/test_export_kwargs_serde.py

+47
Original file line numberDiff line numberDiff line change
@@ -525,3 +525,50 @@ def forward(self, x, b=5, c=None, d=None):
525525
engine = convert_exported_program_to_serialized_trt_engine(
526526
exp_program, **compile_spec
527527
)
528+
529+
530+
def test_custom_model_compile_engine_with_pure_kwarg_inputs():
531+
class net(nn.Module):
532+
def __init__(self):
533+
super().__init__()
534+
self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
535+
self.bn = nn.BatchNorm2d(12)
536+
self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
537+
self.fc1 = nn.Linear(12 * 56 * 56, 10)
538+
539+
def forward(self, x, b=5, c=None, d=None):
540+
x = self.conv1(x)
541+
x = F.relu(x)
542+
x = self.bn(x)
543+
x = F.max_pool2d(x, (2, 2))
544+
x = self.conv2(x)
545+
x = F.relu(x)
546+
x = F.max_pool2d(x, (2, 2))
547+
x = torch.flatten(x, 1)
548+
x = x + b
549+
if c is not None:
550+
x = x * c
551+
if d is not None:
552+
x = x - d["value"]
553+
return self.fc1(x)
554+
555+
model = net().eval().to("cuda")
556+
kwargs = {
557+
"x": torch.rand((1, 3, 224, 224)).to("cuda"),
558+
"b": torch.tensor(6).to("cuda"),
559+
"d": {"value": torch.tensor(8).to("cuda")},
560+
}
561+
562+
compile_spec = {
563+
"arg_inputs": (),
564+
"kwarg_inputs": kwargs,
565+
"device": torchtrt.Device("cuda:0"),
566+
"enabled_precisions": {torch.float},
567+
"pass_through_build_failures": True,
568+
"optimization_level": 1,
569+
"min_block_size": 1,
570+
"ir": "dynamo",
571+
}
572+
573+
exp_program = torch.export.export(model, args=(), kwargs=kwargs)
574+
_ = convert_exported_program_to_serialized_trt_engine(exp_program, **compile_spec)

0 commit comments

Comments
 (0)