Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nsys when running multiple ops #75

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from tritonbench.operators import load_opbench_by_name
from tritonbench.operators_collection import list_operators_by_collection
from tritonbench.utils.gpu_utils import gpu_lockdown
from tritonbench.utils.parser import add_cmd_parameter, get_parser, remove_cmd_parameter
from tritonbench.utils.parser import get_parser
from tritonbench.utils.path_utils import add_cmd_parameter, remove_cmd_parameter

from tritonbench.utils.triton_op import BenchmarkOperatorResult, IS_FBCODE

Expand Down
30 changes: 0 additions & 30 deletions tritonbench/utils/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,33 +188,3 @@ def get_parser(args=None):
"Neither operator nor operator collection is specified. Running all operators in the default collection."
)
return parser


def _find_param_loc(params, key: str) -> int:
try:
return params.index(key)
except ValueError:
return -1


def _remove_params(params, loc):
if loc == -1:
return params
if loc == len(params) - 1:
return params[:loc]
if params[loc + 1].startswith("--"):
return params[:loc] + params[loc + 1 :]
if loc == len(params) - 2:
return params[:loc]
return params[:loc] + params[loc + 2 :]


def add_cmd_parameter(args: List[str], name: str, value: str) -> List[str]:
args.append(name)
args.append(value)
return args


def remove_cmd_parameter(args: List[str], name: str) -> List[str]:
loc = _find_param_loc(args, name)
return _remove_params(args, loc)
31 changes: 31 additions & 0 deletions tritonbench/utils/path_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys

from pathlib import Path
from typing import List

REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
SUBMODULE_PATH = REPO_PATH.joinpath("submodules")
Expand Down Expand Up @@ -35,3 +36,33 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
os.environ = self.os_environ.copy()


def _find_param_loc(params, key: str) -> int:
try:
return params.index(key)
except ValueError:
return -1


def _remove_params(params, loc):
if loc == -1:
return params
if loc == len(params) - 1:
return params[:loc]
if params[loc + 1].startswith("--"):
return params[:loc] + params[loc + 1 :]
if loc == len(params) - 2:
return params[:loc]
return params[:loc] + params[loc + 2 :]


def add_cmd_parameter(args: List[str], name: str, value: str) -> List[str]:
args.append(name)
args.append(value)
return args


def remove_cmd_parameter(args: List[str], name: str) -> List[str]:
loc = _find_param_loc(args, name)
return _remove_params(args, loc)
28 changes: 6 additions & 22 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
set_random_seed,
)
from tritonbench.utils.input import input_cast
from tritonbench.utils.path_utils import add_cmd_parameter, remove_cmd_parameter

try:
from tqdm import tqdm
Expand Down Expand Up @@ -155,19 +156,6 @@ def llama_shapes():
return [(bs, n, k, None) for bs, (k, n) in product(BS, KN)]


def _find_param_loc(l, key: str) -> int:
try:
return l.index(key)
except ValueError:
return -1


def _remove_params(l, loc):
if loc == -1:
return l
return l[:loc] + l[loc + 2 :]


def _split_params_by_comma(params: Optional[str]) -> List[str]:
if params == None:
return []
Expand Down Expand Up @@ -1228,10 +1216,10 @@ def nsys_rep(self, input_id: int, fn_name: str) -> str:

op_task_args = [] if IS_FBCODE else [sys.executable]
op_task_args.extend(copy.deepcopy(sys.argv))
op_task_args = remove_cmd_parameter(op_task_args, "--op")
op_task_args = add_cmd_parameter(op_task_args, "--op", self.name)
for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]:
op_task_args = _remove_params(
op_task_args, _find_param_loc(op_task_args, override_option)
)
op_task_args = remove_cmd_parameter(op_task_args, override_option)
op_task_args.extend(
[
"--only",
Expand Down Expand Up @@ -1291,9 +1279,7 @@ def ncu_trace(
op_task_args = [] if IS_FBCODE else [sys.executable]
op_task_args.extend(copy.deepcopy(sys.argv))
for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]:
op_task_args = _remove_params(
op_task_args, _find_param_loc(op_task_args, override_option)
)
op_task_args = remove_cmd_parameter(op_task_args, override_option)
op_task_args.extend(
[
"--only",
Expand Down Expand Up @@ -1405,9 +1391,7 @@ def compile_time(

op_task_args = copy.deepcopy(self._raw_extra_args)
for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]:
op_task_args = _remove_params(
op_task_args, _find_param_loc(op_task_args, override_option)
)
op_task_args = remove_cmd_parameter(op_task_args, override_option)
op_task_args.extend(
[
"--only",
Expand Down
Loading