Skip to content

Commit c913436

Browse files
committed
fix: fix rmsnorm dispatch to use one dispatch
1 parent de90bab commit c913436

File tree

1 file changed

+11
-41
lines changed

1 file changed

+11
-41
lines changed

src/cambricon/rmsnorm/rms_norm.h

Lines changed: 11 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -33,48 +33,18 @@ class Operator<RmsNorm, Device::Type::kCambricon> : public RmsNorm {
3333

3434
DispatchFunc<
3535
List<DataType::kFloat16, DataType::kBFloat16, DataType::kFloat32>,
36-
List<Device::Type::kCambricon>>(
37-
{static_cast<int64_t>(input.dtype()),
38-
static_cast<int64_t>(Device::Type::kCambricon)},
39-
0,
40-
[&](auto input_tag) {
41-
constexpr DataType IDT = static_cast<DataType>(ListGet<0>(input_tag));
42-
using InputT = TypeMapType<IDT>;
43-
DispatchFunc<
44-
List<DataType::kFloat16, DataType::kBFloat16, DataType::kFloat32>,
45-
List<Device::Type::kCambricon>>(
46-
{static_cast<int64_t>(weight.dtype()),
47-
static_cast<int64_t>(Device::Type::kCambricon)},
48-
0,
49-
[&](auto weight_tag) {
50-
constexpr DataType WDT =
51-
static_cast<DataType>(ListGet<0>(weight_tag));
52-
using WeightT = TypeMapType<WDT>;
53-
54-
RmsnormUnion<InputT, WeightT>(
55-
workspace, core_per_cluster, cluster_count, queue,
56-
out.data(), input.data(), weight.data(), out_shape_.data(),
57-
out_strides_.data(), input_strides_.data(), eps, ndim_);
58-
},
59-
"CambriconRmsNorm::operator() - weight dispatch", List<>{});
36+
List<DataType::kFloat16, DataType::kBFloat16, DataType::kFloat32>>(
37+
{input.dtype(), weight.dtype()},
38+
[&](auto input_tag, auto weight_tag) {
39+
using InputT = typename decltype(input_tag)::type;
40+
using WeightT = typename decltype(weight_tag)::type;
41+
42+
RmsnormUnion<InputT, WeightT>(
43+
workspace, core_per_cluster, cluster_count, queue,
44+
out.data(), input.data(), weight.data(), out_shape_.data(),
45+
out_strides_.data(), input_strides_.data(), eps, ndim_);
6046
},
61-
"CambriconRmsNorm::operator() - output dispatch", List<>{});
62-
// DispatchFunc<List<DataType::kFloat16, DataType::kBFloat16,
63-
// DataType::kFloat32>,
64-
// List<DataType::kFloat16, DataType::kBFloat16,
65-
// DataType::kFloat32>>(
66-
// {input.dtype(), weight.dtype()},
67-
// [&](auto input_tag, auto weight_tag) {
68-
// using InputT = typename decltype(input_tag)::type;
69-
// using WeightT = typename decltype(weight_tag)::type;
70-
71-
// RmsnormUnion<InputT, WeightT>(
72-
// workspace, core_per_cluster, cluster_count, queue,
73-
// out.data(), input.data(), weight.data(),
74-
// out_shape_.data(), out_strides_.data(),
75-
// input_strides_.data(), eps, ndim_);
76-
// },
77-
// "CambriconRmsNorm::operator()");
47+
"CambriconRmsNorm::operator() - output dispatch");
7848
}
7949

8050
// Destructor: passes default_workspace_ to cnrtFree. Any result of that call
// is discarded.
~Operator() {
  cnrtFree(default_workspace_);
}

0 commit comments

Comments
 (0)