Skip to content

Commit

Permalink
Fix nsys when running multiple ops (#75)
Browse files Browse the repository at this point in the history
Summary:
When spawning the subprocess for nsys, we need to restrict the run to a single operator.

Pull Request resolved: #75

Test Plan:
```
 python run.py --op embedding,rms_norm --num-inputs 1 --metrics nsys_rep --csv --dump-csv
```

Reviewed By: FindHao

Differential Revision: D66387768

Pulled By: xuzhao9

fbshipit-source-id: 6728903b0b23a4980ae6f0b002a9b6121055e1e0
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Nov 22, 2024
1 parent 648466b commit e40e68f
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 53 deletions.
3 changes: 2 additions & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from tritonbench.operators import load_opbench_by_name
from tritonbench.operators_collection import list_operators_by_collection
from tritonbench.utils.gpu_utils import gpu_lockdown
from tritonbench.utils.parser import add_cmd_parameter, get_parser, remove_cmd_parameter
from tritonbench.utils.parser import get_parser
from tritonbench.utils.path_utils import add_cmd_parameter, remove_cmd_parameter

from tritonbench.utils.triton_op import BenchmarkOperatorResult, IS_FBCODE

Expand Down
30 changes: 0 additions & 30 deletions tritonbench/utils/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,33 +188,3 @@ def get_parser(args=None):
"Neither operator nor operator collection is specified. Running all operators in the default collection."
)
return parser


def _find_param_loc(params, key: str) -> int:
try:
return params.index(key)
except ValueError:
return -1


def _remove_params(params, loc):
if loc == -1:
return params
if loc == len(params) - 1:
return params[:loc]
if params[loc + 1].startswith("--"):
return params[:loc] + params[loc + 1 :]
if loc == len(params) - 2:
return params[:loc]
return params[:loc] + params[loc + 2 :]


def add_cmd_parameter(args: List[str], name: str, value: str) -> List[str]:
    """Append the ``name value`` pair to ``args`` in place and return ``args``."""
    args.extend((name, value))
    return args


def remove_cmd_parameter(args: List[str], name: str) -> List[str]:
    """Return ``args`` without the first occurrence of option ``name``.

    Locating the option is delegated to ``_find_param_loc`` and removing it
    (plus its value, if present) to ``_remove_params``.
    """
    return _remove_params(args, _find_param_loc(args, name))
31 changes: 31 additions & 0 deletions tritonbench/utils/path_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys

from pathlib import Path
from typing import List

REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
SUBMODULE_PATH = REPO_PATH.joinpath("submodules")
Expand Down Expand Up @@ -35,3 +36,33 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
    # Restore the environment on context exit by rebinding ``os.environ`` to
    # a copy of the mapping held in ``self.os_environ`` (presumably captured
    # in ``__enter__``, which is not visible here -- confirm).
    # NOTE(review): this rebinds the ``os.environ`` name instead of mutating
    # the existing mapping in place; verify the stored object's type keeps
    # the behaviour callers expect from ``os.environ``.
    os.environ = self.os_environ.copy()


def _find_param_loc(params, key: str) -> int:
try:
return params.index(key)
except ValueError:
return -1


def _remove_params(params, loc):
if loc == -1:
return params
if loc == len(params) - 1:
return params[:loc]
if params[loc + 1].startswith("--"):
return params[:loc] + params[loc + 1 :]
if loc == len(params) - 2:
return params[:loc]
return params[:loc] + params[loc + 2 :]


def add_cmd_parameter(args: List[str], name: str, value: str) -> List[str]:
    """Add option ``name`` with ``value`` to the end of ``args``.

    The list is modified in place, and the same list object is returned so
    the call can be used in an assignment.
    """
    args += [name, value]
    return args


def remove_cmd_parameter(args: List[str], name: str) -> List[str]:
    """Remove the first occurrence of option ``name`` from ``args``.

    The lookup is done by ``_find_param_loc`` and the removal of the option
    (and its value, when one follows) by ``_remove_params``.
    """
    return _remove_params(args, _find_param_loc(args, name))
28 changes: 6 additions & 22 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
set_random_seed,
)
from tritonbench.utils.input import input_cast
from tritonbench.utils.path_utils import add_cmd_parameter, remove_cmd_parameter

try:
from tqdm import tqdm
Expand Down Expand Up @@ -155,19 +156,6 @@ def llama_shapes():
return [(bs, n, k, None) for bs, (k, n) in product(BS, KN)]


def _find_param_loc(l, key: str) -> int:
try:
return l.index(key)
except ValueError:
return -1


def _remove_params(l, loc):
if loc == -1:
return l
return l[:loc] + l[loc + 2 :]


def _split_params_by_comma(params: Optional[str]) -> List[str]:
if params == None:
return []
Expand Down Expand Up @@ -1228,10 +1216,10 @@ def nsys_rep(self, input_id: int, fn_name: str) -> str:

op_task_args = [] if IS_FBCODE else [sys.executable]
op_task_args.extend(copy.deepcopy(sys.argv))
op_task_args = remove_cmd_parameter(op_task_args, "--op")
op_task_args = add_cmd_parameter(op_task_args, "--op", self.name)
for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]:
op_task_args = _remove_params(
op_task_args, _find_param_loc(op_task_args, override_option)
)
op_task_args = remove_cmd_parameter(op_task_args, override_option)
op_task_args.extend(
[
"--only",
Expand Down Expand Up @@ -1291,9 +1279,7 @@ def ncu_trace(
op_task_args = [] if IS_FBCODE else [sys.executable]
op_task_args.extend(copy.deepcopy(sys.argv))
for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]:
op_task_args = _remove_params(
op_task_args, _find_param_loc(op_task_args, override_option)
)
op_task_args = remove_cmd_parameter(op_task_args, override_option)
op_task_args.extend(
[
"--only",
Expand Down Expand Up @@ -1405,9 +1391,7 @@ def compile_time(

op_task_args = copy.deepcopy(self._raw_extra_args)
for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]:
op_task_args = _remove_params(
op_task_args, _find_param_loc(op_task_args, override_option)
)
op_task_args = remove_cmd_parameter(op_task_args, override_option)
op_task_args.extend(
[
"--only",
Expand Down

0 comments on commit e40e68f

Please sign in to comment.