Skip to content

Commit 19d52c6

Browse files
committed
finish23
1 parent 731a3a8 commit 19d52c6

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

python/tvm/topi/vision/nms.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,12 +227,14 @@ def _collect_selected_indices_ir(
227227
out = ib.buffer_ptr(out)
228228

229229
# Initialize output buffer to zero
230-
# We need to get the output shape from the function signature
231-
# For now, we'll initialize only the first few rows that we know will be used
232-
# This is a temporary fix - the proper solution would be to pass shape info
233-
with ib.for_range(
234-
0, batch_classes * 10, name="init_i"
235-
) as init_i: # Initialize up to 10 rows per batch_class
230+
# Calculate the actual output shape based on max_output_boxes_per_class
231+
if isinstance(max_output_boxes_per_class, int):
232+
max_output_rows = batch_classes * max_output_boxes_per_class
233+
else:
234+
# Fallback to a reasonable default if max_output_boxes_per_class is not an integer
235+
max_output_rows = batch_classes * 10
236+
237+
with ib.for_range(0, max_output_rows, name="init_i") as init_i:
236238
with ib.for_range(0, 3, name="init_j") as init_j: # 3 columns
237239
out[init_i, init_j] = cast(0, "int64")
238240

0 commit comments

Comments
 (0)