mthreads: support multi-backend and introduce mthreads backend #458
Conversation
Signed-off-by: Jian Li <[email protected]>
config: {BLOCK_M: 8, num_warps: 8} causes the per-thread register count to be exceeded when the tensor shape is 4096 * 2304, so reduce BLOCK_M to 4 to support cumsum.
Signed-off-by: Jian Li <[email protected]>
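A sketch of the kind of config change that commit describes, lowering BLOCK_M so the per-thread register budget is not exceeded (the actual cumsum autotune configs in FlagGems may differ; only the BLOCK_M and num_warps values come from the commit message, and the variable name is illustrative):

import triton

# Hypothetical illustration: BLOCK_M = 8 with num_warps = 8 exceeded the
# per-thread register budget for a 4096 x 2304 input, so BLOCK_M is halved.
cumsum_configs = [
    triton.Config({"BLOCK_M": 4}, num_warps=8),
]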
Signed-off-by: Jian Li <[email protected]>
- Torch_musa does not support fp64 input type, so CPU is used as a reference
- Does not support test_accuracy_groupnorm
- Some use cases have accuracy issues in test_embedding
Signed-off-by: Jian Li <[email protected]>
Signed-off-by: Jian Li <[email protected]>
Signed-off-by: Jian Li <[email protected]>
Signed-off-by: Jian Li <[email protected]>
Signed-off-by: jiaqi.wang <[email protected]>
* add gather_backward op
* add debug log in
* perf gather backward
* rebased with master
* scatter rewrite done.
* scatter handling internally overlapping input.
* Scatter reduce now uses atomics.
* remove fp16 from scatter reduce UT.
* sets threadblock size to 128 for scatter.
* Change atomic memory order to relaxed in scatter.

Co-authored-by: awayzjj <[email protected]>
Co-authored-by: StrongSpoon <[email protected]>
1. update multi-backend code
2. fix argmin op that might fail tests under int types

Co-authored-by: mx-flaggems-user <[email protected]>
Co-authored-by: junjian.zhan <[email protected]>
Signed-off-by: jiaqi.wang <[email protected]>
1. resolve_conj: see https://jira.mthreads.com/browse/MTAI-1530
2. fill: torch_musa does not support the case torch.fill(dtype=cpu, dtype=musa).
* add backward of conv2d
* delete useless code
* format code of tests
* modify configs for tuning
* modify autotune config
* delete test flag
* delete useless type convert

Co-authored-by: Jiang Bin <[email protected]>
Signed-off-by: jiaqi.wang <[email protected]>
Signed-off-by: jiaqi.wang <[email protected]>
Signed-off-by: jiaqi.wang <[email protected]>
… into mthreads/master-250225
Signed-off-by: jiaqi.wang <[email protected]>
Signed-off-by: jiaqi.wang <[email protected]>
tests/mt_test_entry.py
@@ -0,0 +1,70 @@
import argparse
This file appears to be for internal use by mthreads. Does it need to be stored on the master branch?
Yes, it's just for internal use; we will remove it.
tests/test_50_ops.py
@@ -0,0 +1,100 @@
import os
ditto
Just for internal use; we will remove it.
tests/accuracy_utils.py
@@ -138,6 +138,19 @@ def to_reference(inp, upcast=False):
    return ref_inp


def to_reference_fp64(inp, upcast=False):
Where is this to_reference_fp64 function used?
we will remove it.
}


HEURISTICS_CONFIG = {
    vendors.NVIDIA: default_heuristics_for_num_warps,
    vendors.METAX: metax_heuristics_for_num_warps,
    vendors.CAMBRICON: cambricon_heuristics_for_num_warps,
    vendors.MTHREADS: default_heuristics_for_num_warps,
There's no need to add default_heuristics_for_num_warps here; the default is already default_heuristics_for_num_warps.
I'll remove it.
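For context, a minimal sketch of the fallback the reviewer is describing, assuming the table is consulted with a dict lookup that falls back to default_heuristics_for_num_warps for unregistered vendors (the exact lookup helper in FlagGems may differ):

def heuristics_for_num_warps(vendor):
    # Vendors without an explicit entry fall back to the default,
    # so vendors.MTHREADS does not need to be registered here.
    return HEURISTICS_CONFIG.get(vendor, default_heuristics_for_num_warps)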
@@ -1071,7 +1071,8 @@ def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None)

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        self._scalar_fn_cache_key = scalar_fn.cache_key
        # FIXME: cache_key is too long and make open file failed.
        self._scalar_fn_cache_key = scalar_fn.cache_key[:33]
Does this problem occur only on mthreads machines?
It no longer occurs; I'll revert this change.
src/flag_gems/ops/weightnorm.py
This is an empty file.
The file mode was changed to 755; I'll change it back to 644.
src/flag_gems/ops/unique.py
@@ -370,7 +370,7 @@ def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool):
    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    ctas_num = global_ctas_num if global_ctas_num < 65536 else 2048
    tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 32
    num_warps = 8 if tiles_per_cta == 1 else 8
If this change is specific to mthreads, please put it in the mthreads directory so it doesn't affect performance on other backends.
OK, we'll move it to the vendor ops directory.
src/flag_gems/ops/isin.py
@@ -217,11 +217,11 @@ def isin_by_search(
    elif M <= 4194304:  # 2 ** 22 = 1024 * 4096
        _, BLOCK_M, num_warps = launch_arg(None, 1024, M, 8)
    elif M <= 8388608:  # 2 ** 23 = 1024 * 8192
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
        _, BLOCK_M, num_warps = launch_arg(None, 2048, M, 8)
If this change is specific to mthreads, please put it in the mthreads directory so it doesn't affect performance on other backends.
examples/model_bert_test.py
@@ -13,7 +13,8 @@
    "prompt",
    ["How are you today?", "What is your name?", "Who are you?", "Where are you from?"],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
mthreads can choose to skip it; don't change the original logic.
OK, this file is just an example and has no effect on testing; we'll keep it as it was.
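A sketch of the skip-based approach the reviewer suggests, which keeps the original parametrize list intact (the test name and skip reason are illustrative; the flag_gems.device_name check mirrors the suggestion made later in this thread and is an assumption):

import pytest
import torch
import flag_gems

@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
def test_bert_prompt(dtype):
    # Skip the unsupported dtype on the mthreads backend instead of
    # editing the shared dtype list.
    if dtype is torch.bfloat16 and flag_gems.device_name == 'musa':
        pytest.skip("bfloat16 is not supported by torch_musa")
    ...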
@@ -40,8 +42,9 @@ def resolve_conj_input_fn(shape, dtype, device):
    # Sorting Operations
    ("topk", torch.topk, FLOAT_DTYPES, topk_input_fn),
    # Complex Operations
    ("resolve_neg", torch.resolve_neg, [torch.cfloat], resolve_neg_input_fn),
    ("resolve_conj", torch.resolve_conj, [torch.cfloat], resolve_conj_input_fn),
Only disable it when device == "musa".
@kiddyjinjin OK, how about changing it to:
special_operations = [
    # Sorting Operations
    ("topk", torch.topk, FLOAT_DTYPES, topk_input_fn),
    # Complex Operations
    ("resolve_neg", torch.resolve_neg, [torch.cfloat], resolve_neg_input_fn)
    if flag_gems.device_name != 'musa' else (),
    ("resolve_conj", torch.resolve_conj, [torch.cfloat], resolve_conj_input_fn)
    if flag_gems.device_name != 'musa' else (),
]
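One possible refinement of that suggestion, filtering out the placeholder entries so downstream code never sees empty tuples (a sketch built from the names in the snippet above; whether the benchmark code tolerates () entries is an assumption worth checking):

candidate_operations = [
    # Sorting Operations
    ("topk", torch.topk, FLOAT_DTYPES, topk_input_fn),
    # Complex Operations
    ("resolve_neg", torch.resolve_neg, [torch.cfloat], resolve_neg_input_fn)
    if flag_gems.device_name != 'musa' else None,
    ("resolve_conj", torch.resolve_conj, [torch.cfloat], resolve_conj_input_fn)
    if flag_gems.device_name != 'musa' else None,
]
# Drop the disabled placeholders before handing the list to the benchmark.
special_operations = [op for op in candidate_operations if op is not None]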
benchmark/test_reduction_perf.py
("amax", torch.amax, FLOAT_DTYPES), | ||
("any", torch.any, FLOAT_DTYPES), | ||
# ("any", torch.any, FLOAT_DTYPES), # mt not support, disable |
if device == "musa":
    forward_operations = []
else:
    forward_operations = []
How about changing it to:
forward_operations = [
    ("all", torch.all, FLOAT_DTYPES) if flag_gems.device_name != 'musa' else (),
    ("amax", torch.amax, FLOAT_DTYPES),
    ("any", torch.any, FLOAT_DTYPES) if flag_gems.device_name != 'musa' else (),
    ...
]
("floor_divide", torch.floor_divide, INT_DTYPES), | ||
("remainder", torch.remainder, INT_DTYPES), | ||
# ("floor_divide", torch.floor_divide, INT_DTYPES), # mt not support, disable | ||
# ("remainder", torch.remainder, INT_DTYPES), # mt not support, disable |
Please distinguish by device here as well.
@@ -72,6 +72,8 @@ class BenchmarkMetrics:
    tflops: Optional[float] = None
    # Utilization (not implemented yet)
    utilization: Optional[float] = None
    # Speedup compared to base data
    compared_speedup: Optional[float] = None
What's the difference between 'compared_speedup' and 'speedup'?
It takes two perf log files and computes the speedup of A against B. For example, it is used to compare the absolute latency of Triton FlagGems running on an MT GPU and on an NVIDIA GPU, to get the speedup.
"I understand. Could you add a comment to indicate that this field is used in the summary_for_plot
script to calculate the speedup across log files, to avoid confusion for those reviewing the code?
no problem
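A sketch of what that clarifying comment could look like on the new field (the field name comes from the diff above; the wording of the comment is an assumption):

from dataclasses import dataclass
from typing import Optional

@dataclass
class BenchmarkMetrics:
    # ... other metric fields elided ...
    # Speedup relative to a baseline perf log. Only populated by the
    # summary_for_plot script when two log files are compared (e.g.
    # FlagGems on an MT GPU vs. an NVIDIA GPU); unused in a single run.
    compared_speedup: Optional[float] = None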
tests/test_blas_ops.py
@@ -119,6 +119,7 @@ def test_accuracy_outer(M, N, dtype):
    gems_assert_close(res_in2_grad, ref_in2_grad, dtype, reduce_dim=M)


@pytest.mark.skip("Segmentation fault")
skipif
ok
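A sketch of the skipif form being suggested, so the test is skipped only on the affected backend (the device check mirrors flag_gems.device_name used elsewhere in this thread, and the reason string and placeholder test name are assumptions):

import pytest
import flag_gems

@pytest.mark.skipif(
    flag_gems.device_name == 'musa',
    reason="Segmentation fault on the mthreads backend",
)
def test_case():  # placeholder name; applies to the test shown in the diff above
    ...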
Signed-off-by: chuanjiang.ma <[email protected]>
Signed-off-by: chuanjiang.ma <[email protected]>
lgtm
PR Category
OP Test
Type of Change
Refactor
Description
Adapt the mthreads backend to the multi-backend framework.
Issue
Progress
Performance
Added performance test adaptation for the MT GPU backend.