From e40e68faab2e7486c38f3dd96c0c23220c80ef60 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Fri, 22 Nov 2024 15:06:08 -0800 Subject: [PATCH] Fix nsys when running multiple ops (#75) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: We need to isolate the operator to a single op when spawning subprocess for nsys. Pull Request resolved: https://github.com/pytorch-labs/tritonbench/pull/75 Test Plan: ``` python run.py --op embedding,rms_norm --num-inputs 1 --metrics nsys_rep --csv --dump-csv ``` Reviewed By: FindHao Differential Revision: D66387768 Pulled By: xuzhao9 fbshipit-source-id: 6728903b0b23a4980ae6f0b002a9b6121055e1e0 --- run.py | 3 ++- tritonbench/utils/parser.py | 30 ------------------------------ tritonbench/utils/path_utils.py | 31 +++++++++++++++++++++++++++++++ tritonbench/utils/triton_op.py | 28 ++++++---------------------- 4 files changed, 39 insertions(+), 53 deletions(-) diff --git a/run.py b/run.py index 9d353c2b..fdc3fb6b 100644 --- a/run.py +++ b/run.py @@ -17,7 +17,8 @@ from tritonbench.operators import load_opbench_by_name from tritonbench.operators_collection import list_operators_by_collection from tritonbench.utils.gpu_utils import gpu_lockdown -from tritonbench.utils.parser import add_cmd_parameter, get_parser, remove_cmd_parameter +from tritonbench.utils.parser import get_parser +from tritonbench.utils.path_utils import add_cmd_parameter, remove_cmd_parameter from tritonbench.utils.triton_op import BenchmarkOperatorResult, IS_FBCODE diff --git a/tritonbench/utils/parser.py b/tritonbench/utils/parser.py index e72053ae..e423cbbe 100644 --- a/tritonbench/utils/parser.py +++ b/tritonbench/utils/parser.py @@ -188,33 +188,3 @@ def get_parser(args=None): "Neither operator nor operator collection is specified. Running all operators in the default collection." 
) return parser - - -def _find_param_loc(params, key: str) -> int: - try: - return params.index(key) - except ValueError: - return -1 - - -def _remove_params(params, loc): - if loc == -1: - return params - if loc == len(params) - 1: - return params[:loc] - if params[loc + 1].startswith("--"): - return params[:loc] + params[loc + 1 :] - if loc == len(params) - 2: - return params[:loc] - return params[:loc] + params[loc + 2 :] - - -def add_cmd_parameter(args: List[str], name: str, value: str) -> List[str]: - args.append(name) - args.append(value) - return args - - -def remove_cmd_parameter(args: List[str], name: str) -> List[str]: - loc = _find_param_loc(args, name) - return _remove_params(args, loc) diff --git a/tritonbench/utils/path_utils.py b/tritonbench/utils/path_utils.py index b12f285f..33e5ddb8 100644 --- a/tritonbench/utils/path_utils.py +++ b/tritonbench/utils/path_utils.py @@ -2,6 +2,7 @@ import sys from pathlib import Path +from typing import List REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent SUBMODULE_PATH = REPO_PATH.joinpath("submodules") @@ -35,3 +36,33 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): os.environ = self.os_environ.copy() + + +def _find_param_loc(params, key: str) -> int: + try: + return params.index(key) + except ValueError: + return -1 + + +def _remove_params(params, loc): + if loc == -1: + return params + if loc == len(params) - 1: + return params[:loc] + if params[loc + 1].startswith("--"): + return params[:loc] + params[loc + 1 :] + if loc == len(params) - 2: + return params[:loc] + return params[:loc] + params[loc + 2 :] + + +def add_cmd_parameter(args: List[str], name: str, value: str) -> List[str]: + args.append(name) + args.append(value) + return args + + +def remove_cmd_parameter(args: List[str], name: str) -> List[str]: + loc = _find_param_loc(args, name) + return _remove_params(args, loc) diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index 
3d6e57e1..1cad72bf 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -30,6 +30,7 @@ set_random_seed, ) from tritonbench.utils.input import input_cast +from tritonbench.utils.path_utils import add_cmd_parameter, remove_cmd_parameter try: from tqdm import tqdm @@ -155,19 +156,6 @@ def llama_shapes(): return [(bs, n, k, None) for bs, (k, n) in product(BS, KN)] -def _find_param_loc(l, key: str) -> int: - try: - return l.index(key) - except ValueError: - return -1 - - -def _remove_params(l, loc): - if loc == -1: - return l - return l[:loc] + l[loc + 2 :] - - def _split_params_by_comma(params: Optional[str]) -> List[str]: if params == None: return [] @@ -1228,10 +1216,10 @@ def nsys_rep(self, input_id: int, fn_name: str) -> str: op_task_args = [] if IS_FBCODE else [sys.executable] op_task_args.extend(copy.deepcopy(sys.argv)) + op_task_args = remove_cmd_parameter(op_task_args, "--op") + op_task_args = add_cmd_parameter(op_task_args, "--op", self.name) for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]: - op_task_args = _remove_params( - op_task_args, _find_param_loc(op_task_args, override_option) - ) + op_task_args = remove_cmd_parameter(op_task_args, override_option) op_task_args.extend( [ "--only", @@ -1291,9 +1279,7 @@ def ncu_trace( op_task_args = [] if IS_FBCODE else [sys.executable] op_task_args.extend(copy.deepcopy(sys.argv)) for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]: - op_task_args = _remove_params( - op_task_args, _find_param_loc(op_task_args, override_option) - ) + op_task_args = remove_cmd_parameter(op_task_args, override_option) op_task_args.extend( [ "--only", @@ -1405,9 +1391,7 @@ def compile_time( op_task_args = copy.deepcopy(self._raw_extra_args) for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]: - op_task_args = _remove_params( - op_task_args, _find_param_loc(op_task_args, override_option) - ) + op_task_args = 
remove_cmd_parameter(op_task_args, override_option) op_task_args.extend( [ "--only",