
mthreads: support multi-backend and introduce mthreads backend #458

Merged · 201 commits merged into master from mthreads/master-250225 · Feb 27, 2025

Conversation

@machuanjiang (Collaborator) commented Feb 25, 2025

PR Category

OP Test

Type of Change

Refactor

Description

Adapt the mthreads backend to the multi-backend framework.

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

Added performance test adaptation for the MT GPU backend.

yuzhe-wu and others added 30 commits February 20, 2025 17:33
config: {BLOCK_M: 8, num_warps: 8} causes the per-thread register limit to be
exceeded when the tensor shape is 4096 * 2304, so reduce BLOCK_M to 4 to
support cumsum (see the config sketch after this commit list).
- Torch_musa does not support fp64 input type, so CPU is used as a reference
- Does not support test_accuracy_groupnorm

- Some use cases have accuracy issues in test_embedding
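
For context on the BLOCK_M reduction in the cumsum commit above, a minimal sketch of what such an autotune-config change could look like (the config structure and names here are illustrative, not the actual FlagGems cumsum configs):

import triton

# Illustrative only: on MT GPUs, {BLOCK_M: 8, num_warps: 8} exceeds the
# per-thread register budget for shapes like 4096 x 2304, so BLOCK_M is
# reduced to 4 while keeping num_warps at 8.
CUMSUM_CONFIGS = [
    triton.Config({"BLOCK_M": 4, "BLOCK_N": 2048}, num_warps=8),
    triton.Config({"BLOCK_M": 4, "BLOCK_N": 1024}, num_warps=4),
]
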
StrongSpoon and others added 16 commits February 25, 2025 08:02
* add gather_backward op

* add debug log in

* perf gather backward

* rebased with master

* scatter rewrite done.

* scatter handling internally overlapping input.

* Scatter reduce now uses atomics.

* remove fp16 from scatter reduce UT.

* sets threadblock size to 128 for scatter.

* Change atomic memory order to relaxed in scatter.

---------

Co-authored-by: awayzjj <[email protected]>
Co-authored-by: StrongSpoon <[email protected]>
1. update multi-backend code
2. fix argmin op that might fail tests under int types

Co-authored-by: mx-flaggems-user <[email protected]>
1. resolve_conj: see https://jira.mthreads.com/browse/MTAI-1530
2. fill: torch_musa does not support the case torch.fill(dtype=cpu, dtype=musa).
* add backward of conv2d

* delete useless code

* format code of tests

* modify configs for tuning

* modify autotune config

* delete test flag

* delete useless type convert

---------

Co-authored-by: Jiang Bin <[email protected]>
Signed-off-by: jiaqi.wang <jiaqi.wang@mthreads.com>
@@ -0,0 +1,70 @@
import argparse
Collaborator:

This file appears to be for internal use by mthreads. Does it need to be stored on the master branch?

Collaborator (Author):

Yes, it's just for internal usage; we will remove it.

@@ -0,0 +1,100 @@
import os
Collaborator:

ditto

Collaborator (Author):

It's just for internal usage; we will remove it.

@@ -138,6 +138,19 @@ def to_reference(inp, upcast=False):
return ref_inp


def to_reference_fp64(inp, upcast=False):
Collaborator:

Where is this to_reference_fp64 function used?

Collaborator (Author):

we will remove it.

}

HEURISTICS_CONFIG = {
vendors.NVIDIA: default_heuristics_for_num_warps,
vendors.METAX: metax_heuristics_for_num_warps,
vendors.CAMBRICON: cambricon_heuristics_for_num_warps,
vendors.MTHREADS: default_heuristics_for_num_warps,
Collaborator:

No need to add default_heuristics_for_num_warps here; the default is already default_heuristics_for_num_warps.

Collaborator (Author):

I'll remove it.
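
For reference, the reviewer's point is that vendors without special heuristics can rely on a fallback; a minimal sketch, assuming the table is consumed via a plain dict lookup (the actual FlagGems lookup site may differ):

# Sketch only: with a .get() fallback, vendors such as MTHREADS that use the
# default heuristics do not need an explicit entry in HEURISTICS_CONFIG.
heuristics = HEURISTICS_CONFIG.get(vendor, default_heuristics_for_num_warps)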

@@ -1071,7 +1071,8 @@ def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None)

assert isinstance(scalar_fn, JITFunction)
self._scalar_fn = scalar_fn
self._scalar_fn_cache_key = scalar_fn.cache_key
# FIXME: cache_key is too long and make open file failed.
self._scalar_fn_cache_key = scalar_fn.cache_key[:33]
Collaborator:

Does this problem only occur on mthreads machines?

Collaborator (Author):

It no longer occurs; I'll revert this change.
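
As an illustrative alternative (not part of this PR), hashing the key would also give a fixed-length, openable file name without the collision risk of keeping only a 33-character prefix:

import hashlib

def short_cache_key(cache_key: str) -> str:
    # Illustrative sketch only: a fixed-length digest of the full cache key
    # avoids over-long file names while still distinguishing different keys.
    return hashlib.md5(cache_key.encode()).hexdigest()  # 32 hex characters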

Collaborator:

This is an empty file.

Collaborator (Author):

The file mode was changed to 755; I'll change it back to 644.

@@ -370,7 +370,7 @@ def sorted_quick_unique_flat(sorted_data: torch.Tensor, return_counts: bool):
next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
ctas_num = global_ctas_num if global_ctas_num < 65536 else 2048
tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
num_warps = 8 if tiles_per_cta == 1 else 32
num_warps = 8 if tiles_per_cta == 1 else 8
Collaborator:

If this is a change specific to mthreads, please put it in the mthreads directory so it does not affect performance on other backends.

Collaborator (Author):

OK, we'll move it to the vendor ops directory.
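
A minimal sketch of the kind of vendor-scoped override the reviewer is asking for, reusing the flag_gems.device_name check from the author's snippet later in this review (the exact dispatch mechanism in the vendor ops directory may differ):

import flag_gems

def pick_num_warps(tiles_per_cta: int) -> int:
    # Sketch only: keep the generic tuning for other backends and override
    # num_warps just for the MUSA backend, rather than editing the shared path.
    if flag_gems.device_name == "musa":
        return 8                                # MT GPU: 8 warps in both cases
    return 8 if tiles_per_cta == 1 else 32      # original generic tuning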

@@ -217,11 +217,11 @@ def isin_by_search(
elif M <= 4194304: # 2 ** 22 = 1024 * 4096
_, BLOCK_M, num_warps = launch_arg(None, 1024, M, 8)
elif M <= 8388608: # 2 ** 23 = 1024 * 8192
_, BLOCK_M, num_warps = launch_arg(None, 2048, M, 16)
_, BLOCK_M, num_warps = launch_arg(None, 2048, M, 8)
Collaborator:

If this is a change specific to mthreads, please put it in the mthreads directory so it does not affect performance on other backends.

@@ -13,7 +13,8 @@
"prompt",
["How are you today?", "What is your name?", "Who are you?", "Where are you from?"],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
Collaborator:

mthreads can choose to skip it. Don't change the original logic.

Collaborator (Author):

OK, this file seems to be an example with no effect on testing; we'll keep it as it was.
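
For reference, the skip the reviewer suggests can be expressed per parameter instead of editing the dtype list; a minimal sketch (the flag_gems.device_name check mirrors the author's snippet later in this review and is an assumption about the available API; the test name is a placeholder):

import pytest
import torch
import flag_gems

# Sketch only: keep all dtypes and skip bfloat16 just on the MUSA backend,
# rather than commenting out the original parametrization.
@pytest.mark.parametrize(
    "dtype",
    [
        torch.float16,
        torch.float32,
        pytest.param(
            torch.bfloat16,
            marks=pytest.mark.skipif(
                flag_gems.device_name == "musa",
                reason="bfloat16 not supported on musa",
            ),
        ),
    ],
)
def test_prompt_dtype(dtype):
    ...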

@@ -40,8 +42,9 @@ def resolve_conj_input_fn(shape, dtype, device):
# Sorting Operations
("topk", torch.topk, FLOAT_DTYPES, topk_input_fn),
# Complex Operations
("resolve_neg", torch.resolve_neg, [torch.cfloat], resolve_neg_input_fn),
("resolve_conj", torch.resolve_conj, [torch.cfloat], resolve_conj_input_fn),
Collaborator:

This should only be disabled when device == "musa".

Collaborator (Author):

> This should only be disabled when device == "musa".

@kiddyjinjin OK, how about changing it to:

special_operations = [
    # Sorting Operations
    ("topk", torch.topk, FLOAT_DTYPES, topk_input_fn),
    # Complex Operations
    ("resolve_neg", torch.resolve_neg, [torch.cfloat], resolve_neg_input_fn)
    if flag_gems.device_name != 'musa' else (),
    ("resolve_conj", torch.resolve_conj, [torch.cfloat], resolve_conj_input_fn)
    if flag_gems.device_name != 'musa' else (),
]

("amax", torch.amax, FLOAT_DTYPES),
("any", torch.any, FLOAT_DTYPES),
# ("any", torch.any, FLOAT_DTYPES), # mt not support, disable
Collaborator:

if device == "musa":
    forward_operations = []
else:
    forward_operations = []

Collaborator (Author):

> if device == "musa": forward_operations = [] else: forward_operations = []

How about changing it to:

forward_operations = [
    ("all", torch.all, FLOAT_DTYPES) if flag_gems.device_name != 'musa' else (),
    ("amax", torch.amax, FLOAT_DTYPES),
    ("any", torch.any, FLOAT_DTYPES) if flag_gems.device_name != 'musa' else (),
    ...
]

("floor_divide", torch.floor_divide, INT_DTYPES),
("remainder", torch.remainder, INT_DTYPES),
# ("floor_divide", torch.floor_divide, INT_DTYPES), # mt not support, disable
# ("remainder", torch.remainder, INT_DTYPES), # mt not support, disable
Collaborator:

Please distinguish the device here as well.

@@ -72,6 +72,8 @@ class BenchmarkMetrics:
tflops: Optional[float] = None
# Utilization (not implemented yet)
utilization: Optional[float] = None
# Speedup compared to base data
compared_speedup: Optional[float] = None
Collaborator:

What's the difference between 'compared_speedup' and 'speedup'?

Collaborator (Author):

It takes two perf log files as input and calculates the speedup of A against B. For example, it is used to compare the absolute latency of Triton FlagGems running on an MT GPU versus an NV GPU, to get the speedup.

Collaborator:

"I understand. Could you add a comment to indicate that this field is used in the summary_for_plot script to calculate the speedup across log files, to avoid confusion for those reviewing the code?

Collaborator (Author):

no problem
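
A minimal sketch of the comparison the author describes, assuming each perf log reduces to a latency per benchmark case (the names and log format here are illustrative, not the actual summary_for_plot implementation):

# Illustrative sketch: speedup of backend A over backend B from two perf logs,
# e.g. FlagGems latencies measured on an MT GPU versus an NVIDIA GPU.
def compared_speedup(latency_a_ms: dict, latency_b_ms: dict) -> dict:
    # A value > 1 means backend A is faster than backend B for that case.
    return {
        case: latency_b_ms[case] / latency_a_ms[case]
        for case in latency_a_ms.keys() & latency_b_ms.keys()
    }

# Hypothetical latencies in milliseconds for a single benchmark case.
speedups = compared_speedup({"softmax/4096x2304": 0.42}, {"softmax/4096x2304": 0.35})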

@@ -119,6 +119,7 @@ def test_accuracy_outer(M, N, dtype):
gems_assert_close(res_in2_grad, ref_in2_grad, dtype, reduce_dim=M)


@pytest.mark.skip("Segmentation fault")
Collaborator:

Please use skipif instead of an unconditional skip.

Collaborator (Author):

ok
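
A minimal sketch of the skipif form the reviewer asks for, reusing the flag_gems.device_name check from the author's earlier snippet (the test name and condition below are placeholders, not the actual code in this PR):

import pytest
import flag_gems

# Sketch only: skip the test only where the segmentation fault occurs,
# instead of skipping it unconditionally for every backend.
@pytest.mark.skipif(
    flag_gems.device_name == "musa", reason="Segmentation fault on the musa backend"
)
def test_affected_case():
    ...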

@machuanjiang machuanjiang force-pushed the mthreads/master-250225 branch 2 times, most recently from 2a1aa57 to f0b0476 on February 26, 2025 09:45
@machuanjiang machuanjiang force-pushed the mthreads/master-250225 branch from f0b0476 to 4a0fda6 on February 26, 2025 14:15
@machuanjiang machuanjiang changed the title from "upstream of mthreads' fork" to "mthreads: support multi-backend and introduce mthreads backend" on Feb 27, 2025
@Galaxy1458 (Collaborator) left a comment:

lgtm

@Galaxy1458 Galaxy1458 merged commit d5b888f into master Feb 27, 2025
8 of 9 checks passed
@Galaxy1458 Galaxy1458 deleted the mthreads/master-250225 branch February 27, 2025 07:13