@@ -33,48 +33,18 @@ class Operator<RmsNorm, Device::Type::kCambricon> : public RmsNorm {
3333
3434 DispatchFunc<
3535 List<DataType::kFloat16 , DataType::kBFloat16 , DataType::kFloat32 >,
36- List<Device::Type::kCambricon >>(
37- {static_cast <int64_t >(input.dtype ()),
38- static_cast <int64_t >(Device::Type::kCambricon )},
39- 0 ,
40- [&](auto input_tag) {
41- constexpr DataType IDT = static_cast <DataType>(ListGet<0 >(input_tag));
42- using InputT = TypeMapType<IDT>;
43- DispatchFunc<
44- List<DataType::kFloat16 , DataType::kBFloat16 , DataType::kFloat32 >,
45- List<Device::Type::kCambricon >>(
46- {static_cast <int64_t >(weight.dtype ()),
47- static_cast <int64_t >(Device::Type::kCambricon )},
48- 0 ,
49- [&](auto weight_tag) {
50- constexpr DataType WDT =
51- static_cast <DataType>(ListGet<0 >(weight_tag));
52- using WeightT = TypeMapType<WDT>;
53-
54- RmsnormUnion<InputT, WeightT>(
55- workspace, core_per_cluster, cluster_count, queue,
56- out.data (), input.data (), weight.data (), out_shape_.data (),
57- out_strides_.data (), input_strides_.data (), eps, ndim_);
58- },
59- " CambriconRmsNorm::operator() - weight dispatch" , List<>{});
36+ List<DataType::kFloat16 , DataType::kBFloat16 , DataType::kFloat32 >>(
37+ {input.dtype (), weight.dtype ()},
38+ [&](auto input_tag, auto weight_tag) {
39+ using InputT = typename decltype (input_tag)::type;
40+ using WeightT = typename decltype (weight_tag)::type;
41+
42+ RmsnormUnion<InputT, WeightT>(
43+ workspace, core_per_cluster, cluster_count, queue,
44+ out.data (), input.data (), weight.data (), out_shape_.data (),
45+ out_strides_.data (), input_strides_.data (), eps, ndim_);
6046 },
61- " CambriconRmsNorm::operator() - output dispatch" , List<>{});
62- // DispatchFunc<List<DataType::kFloat16, DataType::kBFloat16,
63- // DataType::kFloat32>,
64- // List<DataType::kFloat16, DataType::kBFloat16,
65- // DataType::kFloat32>>(
66- // {input.dtype(), weight.dtype()},
67- // [&](auto input_tag, auto weight_tag) {
68- // using InputT = typename decltype(input_tag)::type;
69- // using WeightT = typename decltype(weight_tag)::type;
70-
71- // RmsnormUnion<InputT, WeightT>(
72- // workspace, core_per_cluster, cluster_count, queue,
73- // out.data(), input.data(), weight.data(),
74- // out_shape_.data(), out_strides_.data(),
75- // input_strides_.data(), eps, ndim_);
76- // },
77- // "CambriconRmsNorm::operator()");
47+ " CambriconRmsNorm::operator() - output dispatch" );
7848 }
7949
// Destructor: hands default_workspace_ to cnrtFree when the operator is destroyed.
// NOTE(review): the class's copy/move operations are not visible in this chunk —
// confirm they are deleted or handled so default_workspace_ is not freed twice.
8050 ~Operator () { cnrtFree (default_workspace_); }
0 commit comments