@@ -525,3 +525,50 @@ def forward(self, x, b=5, c=None, d=None):
525
525
engine = convert_exported_program_to_serialized_trt_engine (
526
526
exp_program , ** compile_spec
527
527
)
528
+
529
+
530
def test_custom_model_compile_engine_with_pure_kwarg_inputs():
    """Build a serialized TRT engine for a model fed exclusively through kwargs.

    The model is exported with an empty positional-argument tuple; every
    input (including the nested dict ``d``) is supplied by keyword. The
    optional ``c`` input is deliberately left out, so its branch in
    ``forward`` is not taken. The test passes if export and engine
    conversion complete without raising.
    """

    class net(nn.Module):
        # Layers are created in a fixed order: conv1, bn, conv2, fc1.
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
            self.bn = nn.BatchNorm2d(12)
            self.conv2 = nn.Conv2d(12, 12, 3, padding=1)
            # Two 2x2 max-pools shrink a 224x224 input to 56x56 over 12 channels.
            self.fc1 = nn.Linear(12 * 56 * 56, 10)

        def forward(self, x, b=5, c=None, d=None):
            # Stage 1: conv -> relu -> batch norm -> 2x2 max-pool.
            x = F.max_pool2d(self.bn(F.relu(self.conv1(x))), (2, 2))
            # Stage 2: conv -> relu -> 2x2 max-pool.
            x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
            # Flatten everything but the batch dimension, then apply the offset.
            x = torch.flatten(x, 1) + b
            if c is not None:
                x = x * c
            if d is not None:
                x = x - d["value"]
            return self.fc1(x)

    module = net().eval().to("cuda")

    # All inputs are keyword-only; "c" is omitted on purpose.
    kwarg_inputs = {
        "x": torch.rand((1, 3, 224, 224)).to("cuda"),
        "b": torch.tensor(6).to("cuda"),
        "d": {"value": torch.tensor(8).to("cuda")},
    }

    engine_spec = {
        "arg_inputs": (),
        "kwarg_inputs": kwarg_inputs,
        "device": torchtrt.Device("cuda:0"),
        "enabled_precisions": {torch.float},
        "pass_through_build_failures": True,
        "optimization_level": 1,
        "min_block_size": 1,
        "ir": "dynamo",
    }

    exported = torch.export.export(module, args=(), kwargs=kwarg_inputs)
    # The serialized engine itself is not inspected; only successful
    # conversion is being exercised here.
    convert_exported_program_to_serialized_trt_engine(exported, **engine_spec)
0 commit comments