@@ -3199,54 +3199,47 @@ def test_nms():
3199
3199
3200
3200
model = helper .make_model (graph , producer_name = "nms_test" )
3201
3201
model .opset_import [0 ].version = 11
3202
-
3202
+
3203
3203
# Use deterministic random inputs for consistent testing
3204
3204
bg = np .random .MT19937 (0 )
3205
3205
rg = np .random .Generator (bg )
3206
3206
boxes = rg .standard_normal (size = boxes_shape ).astype (np .float32 )
3207
3207
scores = rg .standard_normal (size = scores_shape ).astype (np .float32 )
3208
3208
inputs = {"boxes" : boxes , "scores" : scores }
3209
-
3209
+
3210
3210
# Run ONNX Runtime
3211
3211
ort_session = onnxruntime .InferenceSession (
3212
3212
model .SerializeToString (), providers = ["CPUExecutionProvider" ]
3213
3213
)
3214
3214
ort_output = ort_session .run ([], inputs )
3215
-
3215
+
3216
3216
# Run TVM
3217
3217
tvm_model = from_onnx (model , opset = 11 , keep_params_in_input = True )
3218
3218
tvm_model = relax .transform .DecomposeOpsForInference ()(tvm_model )
3219
3219
tvm_model = relax .transform .LegalizeOps ()(tvm_model )
3220
3220
tvm_model , params = relax .frontend .detach_params (tvm_model )
3221
-
3221
+
3222
3222
with tvm .transform .PassContext (opt_level = 3 ):
3223
3223
ex = tvm .compile (tvm_model , target = "llvm" )
3224
3224
vm = relax .VirtualMachine (ex , tvm .cpu ())
3225
-
3225
+
3226
3226
input_list = [
3227
3227
inputs [key .name_hint ] for key in tvm_model ["main" ].params if key .name_hint in inputs
3228
3228
]
3229
3229
if params :
3230
3230
input_list += params ["main" ]
3231
-
3231
+
3232
3232
vm .set_input ("main" , * input_list )
3233
3233
vm .invoke_stateful ("main" )
3234
3234
tvm_output = vm .get_outputs ("main" )
3235
-
3236
- # Custom NMS output comparison
3237
- # TVM outputs fixed shape (6,3), ONNX Runtime outputs dynamic shape (varies)
3238
- # We only compare the valid rows based on the actual output count
3235
+
3239
3236
if isinstance (tvm_output , (list , tuple )):
3240
3237
tvm_selected = tvm_output [0 ].numpy ()
3241
3238
else :
3242
3239
tvm_selected = tvm_output .numpy ()
3243
3240
ort_selected = ort_output [0 ]
3244
-
3245
- # For NMS, compare only the number of valid rows
3246
- # TVM may output more rows with garbage data, but the first N rows should match
3241
+
3247
3242
min_rows = min (tvm_selected .shape [0 ], ort_selected .shape [0 ])
3248
-
3249
- # Compare the first min_rows rows
3250
3243
if min_rows > 0 :
3251
3244
tvm .testing .assert_allclose (
3252
3245
tvm_selected [:min_rows ], ort_selected [:min_rows ], rtol = 1e-5 , atol = 1e-5
0 commit comments