Skip to content

Commit

Permalink
Fix import bug in data_utils.py
Browse files Browse the repository at this point in the history
Summary: As the title says

Reviewed By: nmacchioni

Differential Revision: D65881268

fbshipit-source-id: 91ab130b133e2d35e15244d971882f3b51946331
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Nov 13, 2024
1 parent f63be70 commit 03fa086
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tritonbench/operators/softmax/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _inner():

def get_input_iter(self):
M = 4096
shapes = (tuple(M, 128 * i) for i in range(2, 100))
shapes = [(M, 128 * i) for i in range(2, 100)]
if IS_FBCODE and self.tb_args.production_shapes:
shapes = get_production_shapes(self.name, "softmax")
for M, N in shapes:
Expand Down
2 changes: 1 addition & 1 deletion tritonbench/utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .triton_ops import IS_FBCODE
from .triton_op import IS_FBCODE


def get_production_shapes(op_name, op_type):
Expand Down
6 changes: 5 additions & 1 deletion tritonbench/utils/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,12 @@ def _find_param_loc(params, key: str) -> int:
def _remove_params(params, loc):
if loc == -1:
return params
if (loc + 1) < len(params) and params[loc + 1].startswith("--"):
if loc == len(params) - 1:
return params[:loc]
if params[loc + 1].startswith("--"):
return params[:loc] + params[loc + 1 :]
if loc == len(params) - 2:
return params[:loc]
return params[:loc] + params[loc + 2 :]


Expand Down

0 comments on commit 03fa086

Please sign in to comment.