diff --git a/.github/workflows/ascend-build-and-test.yml b/.github/workflows/ascend-build-and-test.yml index 9e33248f2..8f69be82f 100644 --- a/.github/workflows/ascend-build-and-test.yml +++ b/.github/workflows/ascend-build-and-test.yml @@ -78,37 +78,29 @@ jobs: run: | set -x source /usr/local/Ascend/ascend-toolkit/set_env.sh - python3 third_party/tests/ascend/vector-add.py - python3 third_party/ascend/examples/tutorials/01-vector-add.py - python3 third_party/ascend/examples/tutorials/02-fused-softmax.py - python3 third_party/ascend/examples/tutorials/03-layer-norm.py - python3 third_party/ascend/examples/tutorials/04-fused-attention.py - python3 third_party/ascend/examples/tutorials/06-demo-autotune.py - python3 third_party/ascend/examples/tutorials/07-profiler.py - python3 third_party/ascend/examples/tutorials/09-gather.py - python3 third_party/ascend/examples/tutorials/10-gather_sorted.py - python3 third_party/ascend/examples/tutorials/11-rab_time.py - python3 third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized.py - python3 third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized-flagtree.py - python3 third_party/ascend/examples/tutorials/14-accuracy-comparison.py - python3 python/test/ops/01_vector_add/01_vector_add.py - python3 python/test/ops/abs/abs.py - python3 python/test/ops/addmm/addmm.py - python3 python/test/ops/addmm/addmm_ascend.py - python3 python/test/ops/amax/amax.py - python3 python/test/ops/amax/amax_ascend_perf.py - python3 python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb.py - python3 python/test/ops/apply_rotary_pos_emb/apply_rotary_pos_emb_ascend.py - python3 python/test/ops/argmin/argmin.py - python3 python/test/ops/argmin/argmin_ascend_perf.py - python3 python/test/ops/bmm/bmm_ascend.py - python3 python/test/ops/cumsum/cumsum.py - python3 python/test/ops/min_dim/min_dim.py - python3 python/test/ops/min_dim/min_dim_ascend_perf.py - python3 python/test/ops/sum_dim/sum_dim.py - python3 python/test/ops/varmean/var_mean_ascend.py - python3 -m pytest third_party/ascend/examples/pytest_ut --ignore=third_party/ascend/examples/pytest_ut/test_index_select.py \ - --ignore=third_party/ascend/examples/pytest_ut/test_linearize_permute.py \ - --ignore=third_party/ascend/examples/pytest_ut/test_logical_and.py \ - --ignore=third_party/ascend/examples/pytest_ut/test_logical_or.py \ - --ignore=third_party/ascend/examples/pytest_ut/test_triton_unified_attention.py + # tutorials + pushd third_party/ascend/tutorials + python3 01-vector-add.py + python3 02-fused-softmax.py + python3 03-layer-norm.py + python3 04-fused-attention.py + python3 06-demo-autotune.py + python3 07-profiler.py + python3 08-demo-libentry.py + python3 09-gather.py + python3 10-gather_sorted.py + python3 11-rab_time.py + python3 12-hstu_attention.py + python3 13-matrix-multiplication-optimized.py + python3 14-accuracy-comparison.py + python3 15-embedding_gather_demo.py + popd + # pytest_ut + pushd third_party/ascend/unittest/pytest_ut + python3 -m pytest . 
\ + --ignore=test_index_select.py \ + --ignore=test_linearize_permute.py \ + --ignore=test_logical_and.py \ + --ignore=test_logical_or.py \ + --ignore=test_triton_unified_attention.py + popd diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b55d88627..99abab826 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: hooks: - id: ruff files: '^python/.*' - args: ["--fix", "--line-length", "120"] + args: ["--fix", "--line-length", "120", "--per-file-ignores", "*/__init__.py:E402"] stages: [pre-commit, pre-push, manual] exclude: | (?x)( diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 08872dae0..bdc88f951 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,3 +1,7 @@ +# flagtree backend path specialization +from triton.runtime.driver import spec_path + +spec_path(__path__) """isort:skip_file""" __version__ = '3.2.0' diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index bbe8c047c..50a74d38f 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,3 +1,8 @@ +# flagtree backend path specialization +from triton.runtime.driver import spec_path + +spec_path(__path__) + from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict from .errors import CompilationError diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 6502a5348..3908513c1 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -1,3 +1,7 @@ +# flagtree backend path specialization +from triton.runtime.driver import spec_path + +spec_path(__path__) """isort:skip_file""" # Import order is significant here. @@ -255,6 +259,11 @@ "zeros_like", ] +# flagtree backend specialization +from triton.runtime.driver import spec + +__all__ = spec("language_modify_all", __all__) or __all__ + def str_to_ty(name): if name[0] == "*": diff --git a/python/triton/runtime/__init__.py b/python/triton/runtime/__init__.py index 0b3979d28..9c15f13bc 100644 --- a/python/triton/runtime/__init__.py +++ b/python/triton/runtime/__init__.py @@ -1,3 +1,8 @@ +# flagtree backend path specialization +from triton.runtime.driver import spec_path + +spec_path(__path__) + from .autotuner import (Autotuner, Config, Heuristics, autotune, heuristics) from .cache import RedisRemoteCacheBackend, RemoteCacheBackend from .driver import driver diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 4cfaa7af7..ee5ea6da2 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -78,3 +78,25 @@ def spec_func(function_name: str): func = getattr(spec, function_name) return func return None + + +# flagtree backend path specialization +def spec_path(path_list: list): + """ + TODO: Read "ascend" from FLAGTREE_BACKEND file. 
+ example: input __path__ = ["python/triton/compiler"] + backend_path = "third_party/ascend/backend/spec/python/triton/compiler" + __path__ = [backend_path, "python/triton/compiler"] + """ + import os + if not path_list: + return + current_path = path_list[0] + current_path = current_path.replace(os.sep, "/") + marker = "python/triton" + idx = current_path.find(marker) + if idx != -1: + rel_path = current_path[idx:] + backend_path = os.path.join("third_party/ascend/backend/spec", rel_path) + if os.path.isdir(backend_path): + path_list.insert(0, backend_path) diff --git a/third_party/ascend/CMakeLists.txt b/third_party/ascend/CMakeLists.txt index 1ca32d82d..e8e08543e 100644 --- a/third_party/ascend/CMakeLists.txt +++ b/third_party/ascend/CMakeLists.txt @@ -1,15 +1,42 @@ -#add_subdirectory(triton-adapter triton-adapter) -#add_subdirectory(test) - add_subdirectory(backend/spec/lib) add_subdirectory(${PROJECT_SOURCE_DIR}/include ${PROJECT_BINARY_DIR}/include) add_subdirectory(${PROJECT_SOURCE_DIR}/lib ${PROJECT_BINARY_DIR}/lib) -add_triton_plugin(TritonAscend ${CMAKE_CURRENT_SOURCE_DIR}/triton_ascend.cpp) +add_triton_library(Registrar Registrar.cc) + +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + +# Triton Ascend is dependent on AscendNPU IR +#set(ASCENDNPU_IR_SRC_DIR "${PROJECT_SOURCE_DIR}/third_party/ascendnpu-ir") +#set(ASCENDNPU_IR_BINARY_DIR "${PROJECT_BINARY_DIR}/third_party/ascendnpu-ir") + +set(BISHENGIR_ENABLE_A5_UNPUBLISHED_FEATURES ON) +set(BISHENGIR_BUILD_STANDALONE_IR_ONLY ON) + +#add_subdirectory(${ASCENDNPU_IR_SRC_DIR} ${ASCENDNPU_IR_BINARY_DIR}) +#include_directories(${ASCENDNPU_IR_SRC_DIR}/bishengir/include) +#include_directories(${ASCENDNPU_IR_BINARY_DIR}/bishengir/include) # Tablegen'd files + +#add_subdirectory(include) +#add_subdirectory(lib) + +add_triton_plugin(TritonAscend + ${CMAKE_CURRENT_SOURCE_DIR}/triton_ascend.cc + ${CMAKE_CURRENT_SOURCE_DIR}/ascend_ir.cc + + LINK_LIBS + TritonToLinalgIncubated + BiShengIRScopeDialect + BiShengIRHIVMDialect +) + +# target_link_libraries(TritonAscend PRIVATE Python3::Module pybind11::headers) target_include_directories(TritonAscend PRIVATE - ${CMAKE_SOURCE_DIR}/third_party/flir/include - ${CMAKE_BINARY_DIR}/third_party/flir/include) + ${CMAKE_SOURCE_DIR}/third_party/flir/include + ${CMAKE_BINARY_DIR}/third_party/flir/include) -add_triton_library(Registrar Registrar.cc) +if(TRITON_BUILD_UT) + add_subdirectory(unittest) +endif() diff --git a/third_party/ascend/README.md b/third_party/ascend/README.md deleted file mode 100644 index e69de29bb..000000000 diff --git a/third_party/ascend/ascend_ir.cc b/third_party/ascend/ascend_ir.cc new file mode 100644 index 000000000..7cab69c45 --- /dev/null +++ b/third_party/ascend/ascend_ir.cc @@ -0,0 +1,401 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * Copyright 2018-2020 Philippe Tillet + * Copyright 2020-2022 OpenAI + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "ir.h" +#include "pybind11/pybind11.h" +#include + +#include "bishengir/Dialect/Annotation/IR/Annotation.h" +#include "bishengir/Dialect/HIVM/IR/HIVM.h" +#include "bishengir/Dialect/Scope/IR/Scope.h" + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" +#include "llvm/IR/Instructions.h" + +using namespace mlir; +namespace py = pybind11; + +struct AscendNPUIROpBuilder : public TritonOpBuilder { + std::string target; + static constexpr char kTarget910_95[] = "Ascend910_95"; + + explicit AscendNPUIROpBuilder(MLIRContext *context, std::string target = "") + : TritonOpBuilder(context), target(target) {} + + bool is_910_95() { + // TODO: Use enum instead of strings after enabling HACC in satandalone + // build + constexpr size_t kTargetLen = sizeof(kTarget910_95) - 1; + return target.size() >= kTargetLen && + target.compare(0, kTargetLen, kTarget910_95) == 0; + } +}; + +namespace { +struct ModeAndPipes { + hivm::SyncBlockModeAttr modeAttr = {}; + hivm::PipeAttr cubePipe = {}; + hivm::PipeAttr vectorPipe = {}; +}; + +hivm::TCoreTypeAttr GetCore(MLIRContext *ctx, llvm::StringRef opName, + llvm::StringRef sender) { + // Decide core type + hivm::TCoreTypeAttr core; + if (sender == "cube") { + if (opName == "sync_block_set") + core = hivm::TCoreTypeAttr::get(ctx, hivm::TCoreType::CUBE); + else + core = hivm::TCoreTypeAttr::get(ctx, hivm::TCoreType::VECTOR); + } else { + if (sender != "vector") { + throw std::runtime_error( + "sync_block_set/wait only supports 'cube' or 'vector' as sender"); + } + if (opName == "sync_block_set") + core = hivm::TCoreTypeAttr::get(ctx, hivm::TCoreType::VECTOR); + else + core = hivm::TCoreTypeAttr::get(ctx, hivm::TCoreType::CUBE); + } + + return core; +} + +void buildSyncBlockOp(AscendNPUIROpBuilder &self, const std::string &opName, + std::string &sender, std::string &receiver, Value id, + hivm::PIPE senderPipe, hivm::PIPE receiverPipe) { + auto *ctx = self.getBuilder().getContext(); + hivm::TCoreTypeAttr coreAttr = GetCore(ctx, opName, sender); + hivm::PipeAttr prodPipe = hivm::PipeAttr::get(ctx, senderPipe); + hivm::PipeAttr consPipe = hivm::PipeAttr::get(ctx, receiverPipe); + const size_t I64 = 64; + auto i64Ty = IntegerType::get(ctx, I64); + Value idI64 = id; + if (!id.getType().isInteger(I64)) { + idI64 = mlir::convertScalarToDtype(self.getBuilder(), id.getLoc(), id, + i64Ty, true); + } + if (opName == "sync_block_set") { + self.create(coreAttr, prodPipe, consPipe, idI64); + } else if (opName == "sync_block_wait") { + self.create(coreAttr, prodPipe, consPipe, idI64); + } else { + throw std::runtime_error("Unsupported operation name for SyncBlockOp"); + } +} + +ModeAndPipes GetSyncBlockModeAndPipes(MLIRContext *ctx, + const std::string &mode) { + hivm::SyncBlockModeAttr modeAttr = {}; + hivm::PipeAttr cubePipe = {}; + hivm::PipeAttr vectorPipe = {}; + + if (mode == "all_cube") { + modeAttr = 
hivm::SyncBlockModeAttr::get(ctx, hivm::SyncBlockMode::ALL_CUBE); + cubePipe = hivm::PipeAttr::get(ctx, hivm::PIPE::PIPE_ALL); + vectorPipe = hivm::PipeAttr{}; + } else if (mode == "all_vector") { + modeAttr = + hivm::SyncBlockModeAttr::get(ctx, hivm::SyncBlockMode::ALL_VECTOR); + cubePipe = hivm::PipeAttr{}; + vectorPipe = hivm::PipeAttr::get(ctx, hivm::PIPE::PIPE_ALL); + } else if (mode == "all") { + modeAttr = hivm::SyncBlockModeAttr::get(ctx, hivm::SyncBlockMode::ALL); + cubePipe = hivm::PipeAttr::get(ctx, hivm::PIPE::PIPE_ALL); + vectorPipe = hivm::PipeAttr::get(ctx, hivm::PIPE::PIPE_ALL); + } else if (mode == "all_sub_vector") { + modeAttr = + hivm::SyncBlockModeAttr::get(ctx, hivm::SyncBlockMode::ALL_SUB_VECTOR); + cubePipe = hivm::PipeAttr{}; + vectorPipe = hivm::PipeAttr::get(ctx, hivm::PIPE::PIPE_ALL); + } else { + llvm::report_fatal_error( + llvm::StringRef("Invalid sync-block mode: " + mode)); + } + return {modeAttr, cubePipe, vectorPipe}; +} +} // namespace + +void init_ascend_ir(py::module &&m) { + py::enum_(m, "AddressSpace", py::module_local()) + .value("L1", hivm::AddressSpace::L1) + .value("UB", hivm::AddressSpace::UB) + .value("L0A", hivm::AddressSpace::L0A) + .value("L0B", hivm::AddressSpace::L0B) + .value("L0C", hivm::AddressSpace::L0C) + .export_values(); + + py::enum_(m, "CoreType", py::module_local()) + .value("CUBE", hivm::TCoreType::CUBE) + .value("VECTOR", hivm::TCoreType::VECTOR) + .value("CUBE_OR_VECTOR", hivm::TCoreType::CUBE_OR_VECTOR) + .value("CUBE_AND_VECTOR", hivm::TCoreType::CUBE_AND_VECTOR) + .export_values(); + + py::enum_(m, "PIPE", py::module_local()) + .value("PIPE_S", hivm::PIPE::PIPE_S) + .value("PIPE_V", hivm::PIPE::PIPE_V) + .value("PIPE_M", hivm::PIPE::PIPE_M) + .value("PIPE_MTE1", hivm::PIPE::PIPE_MTE1) + .value("PIPE_MTE2", hivm::PIPE::PIPE_MTE2) + .value("PIPE_MTE3", hivm::PIPE::PIPE_MTE3) + .value("PIPE_ALL", hivm::PIPE::PIPE_ALL) + .value("PIPE_FIX", hivm::PIPE::PIPE_FIX) + .export_values(); + + py::enum_(m, "MODE", py::module_local()) + .value("SIMD", hivm::VFMode::SIMD) + .value("SIMT", hivm::VFMode::SIMT) + .value("MIX", hivm::VFMode::MIX) + .export_values(); + + py::enum_(m, "FixpipeDMAMode", py::module_local()) + .value("NZ2DN", hivm::FixpipeDMAMode::NZ2DN) + .value("NZ2ND", hivm::FixpipeDMAMode::NZ2ND) + .value("NZ2NZ", hivm::FixpipeDMAMode::NZ2NZ) + .export_values(); + + py::enum_(m, "FixpipeDualDstMode", + py::module_local()) + .value("NO_DUAL", hivm::FixpipeDualDstMode::NO_DUAL) + .value("COLUMN_SPLIT", hivm::FixpipeDualDstMode::COLUMN_SPLIT) + .value("ROW_SPLIT", hivm::FixpipeDualDstMode::ROW_SPLIT) + .export_values(); + + py::enum_(m, "FixpipePreQuantMode", + py::module_local()) + .value("NO_QUANT", hivm::FixpipePreQuantMode::NO_QUANT) + .value("F322BF16", hivm::FixpipePreQuantMode::F322BF16) + .value("F322F16", hivm::FixpipePreQuantMode::F322F16) + .value("S322I8", hivm::FixpipePreQuantMode::S322I8) + .export_values(); + + py::enum_(m, "FixpipePreReluMode", + py::module_local()) + .value("LEAKY_RELU", hivm::FixpipePreReluMode::LEAKY_RELU) + .value("NO_RELU", hivm::FixpipePreReluMode::NO_RELU) + .value("NORMAL_RELU", hivm::FixpipePreReluMode::NORMAL_RELU) + .value("P_RELU", hivm::FixpipePreReluMode::P_RELU) + .export_values(); + py::enum_(m, "DataLayout", py::module_local()) + .value("nZ", hivm::DataLayout::nZ) + .value("zN", hivm::DataLayout::zN) + .export_values(); + + m.def("load_dialects", [](MLIRContext &context) { + // Allow unregistered dialects so we can parse HACC attributes without + // registering the dialect + 
context.allowUnregisteredDialects(); + DialectRegistry registry; + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + py::class_( + m, "ascendnpu_ir_builder", py::module_local(), py::dynamic_attr()) + .def(py::init(), py::arg("context"), + py::arg("target") = "") + .def("get_core_type_attr", + [](AscendNPUIROpBuilder &self, + hivm::TCoreType core_type) -> Attribute { + return self.getBuilder().getAttr(core_type); + }) + .def("get_pipe_attr", + [](AscendNPUIROpBuilder &self, hivm::PIPE pipe) -> Attribute { + return self.getBuilder().getAttr(pipe); + }) + .def("get_vf_mode_attr", + [](AscendNPUIROpBuilder &self, hivm::VFMode mode) -> Attribute { + return self.getBuilder().getAttr(mode); + }) + .def("get_t_core_type_attr_name", + [](AscendNPUIROpBuilder &self) -> std::string { + return hivm::TCoreTypeAttr::name.str(); + }) + .def("get_t_core_type_cube_attr", + [](AscendNPUIROpBuilder &self) -> Attribute { + return hivm::TCoreTypeAttr::get(self.getBuilder().getContext(), + hivm::TCoreType::CUBE); + }) + .def("get_t_core_type_vector_attr", + [](AscendNPUIROpBuilder &self) -> Attribute { + return hivm::TCoreTypeAttr::get(self.getBuilder().getContext(), + hivm::TCoreType::VECTOR); + }) + .def("parse_attr", + [](TritonOpBuilder &self, std::string value) -> Attribute { + auto *ctx = self.getBuilder().getContext(); + return mlir::parseAttribute(value, ctx); + }) + .def("create_fixpipe", + [](AscendNPUIROpBuilder &self, Value src, Value dst, + hivm::FixpipeDMAMode dma_mode, + hivm::FixpipeDualDstMode dual_dst_mode, + hivm::FixpipePreQuantMode pre_quant_mode, + hivm::FixpipePreReluMode pre_relu_mode) -> void { + if (!dyn_cast(src.getType())) { + llvm_unreachable("src is not of RankedTensorType"); + } + if (!dyn_cast(dst.getType())) { + llvm_unreachable("dst is not of MemRefType"); + } + auto *ctx = self.getBuilder().getContext(); + auto dma_mode_attr = + mlir::hivm::FixpipeDMAModeAttr::get(ctx, dma_mode); + auto dual_dst_mode_attr = + mlir::hivm::FixpipeDualDstModeAttr::get(ctx, dual_dst_mode); + auto pre_quant_mode_attr = + mlir::hivm::FixpipePreQuantModeAttr::get(ctx, pre_quant_mode); + auto pre_relu_mode_attr = + mlir::hivm::FixpipePreReluModeAttr::get(ctx, pre_relu_mode); + auto channel_split = BoolAttr::get(ctx, false); + auto op = self.create( + mlir::TypeRange{}, src, dst, dma_mode_attr, dual_dst_mode_attr, + pre_quant_mode_attr, pre_relu_mode_attr, channel_split); + }) + .def("create_bind_buffer", + [](TritonOpBuilder &self, Value &src, Value &alloc) -> void { + auto ctx = self.getBuilder().getContext(); + auto bind = StringAttr::get(ctx, "bind_buffer"); + self.create(src, ValueRange{alloc}, + ArrayAttr::get(ctx, bind)); + }) + .def("create_debug_barrier", + [](TritonOpBuilder &self, Value &ptr, const std::string &attrKey, + Attribute &attrVal) { + auto annotationOp = self.create(ptr); + annotationOp->setAttr(self.getBuilder().getStringAttr(attrKey), + attrVal); + }) + .def("create_custom_op", + [](AscendNPUIROpBuilder &self, const std::string &name, + const py::dict &attrs, const std::vector &ins, + const std::vector &outs) -> std::vector { + ValueRange inputs{ins}; + ValueRange outputs{outs}; + TypeRange res_types{outputs}; + auto op = + self.create(res_types, name, inputs, outputs); + for (auto &attr : attrs) { + std::string attr_name = py::cast(attr.first); + Attribute attr_value = py::cast(attr.second); + op->setAttr(attr_name, attr_value); + } + auto results = op->getResults(); + return std::vector(results.begin(), 
results.end()); + }) + .def("create_scope_op", + [](AscendNPUIROpBuilder &self, py::dict &scopeAttrs, + std::vector resultTypes) -> OpState { + llvm::SmallVector attrs; + for (auto item : scopeAttrs) { + std::string key = py::cast(item.first); + Attribute value = py::cast(item.second); + attrs.push_back( + NamedAttribute(self.getBuilder().getStringAttr(key), value)); + } + auto scopeOp = self.create(TypeRange(resultTypes)); + scopeOp->setAttrs(attrs); + return OpState(scopeOp); + }) + .def("scope_return", + [](AscendNPUIROpBuilder &self, + std::vector operands) -> OpState { + return self.create(ValueRange(operands)); + }) + .def("sync_block_set", + [](AscendNPUIROpBuilder &self, std::string &sender, + std::string &receiver, Value id, hivm::PIPE senderPipe, + hivm::PIPE receiverPipe) -> void { + buildSyncBlockOp(self, "sync_block_set", sender, receiver, id, + senderPipe, receiverPipe); + }) + .def("sync_block_wait", + [](AscendNPUIROpBuilder &self, std::string &sender, + std::string &receiver, Value id, hivm::PIPE senderPipe, + hivm::PIPE receiverPipe) -> void { + buildSyncBlockOp(self, "sync_block_wait", sender, receiver, id, + senderPipe, receiverPipe); + }) + .def("get_target_attribute", + [](AscendNPUIROpBuilder &self, + hivm::AddressSpace &addressSpace) -> Attribute { + return hivm::AddressSpaceAttr::get(self.getBuilder().getContext(), + addressSpace); + }) + .def("create_get_sub_vec_id", + [](AscendNPUIROpBuilder &self) -> Value { + auto subBlockIdxOp = self.create(); + auto moduleOp = subBlockIdxOp->getParentOfType(); + auto *ctx = self.getBuilder().getContext(); + // If user explicitly uses sub.block idx, add attribute to module. + // NPU compiler will parse this attribute and disable auto tile and + // bind subblock pass. + moduleOp->setAttr("hivm.disable_auto_tile_and_bind_subblock", + mlir::UnitAttr::get(ctx)); + return subBlockIdxOp; + }) + .def("sync_block_all", + [](AscendNPUIROpBuilder &self, std::string &mode, int id) -> void { + auto *ctx = self.getBuilder().getContext(); + auto [modeAttr, cubePipe, vectorPipe] = + GetSyncBlockModeAndPipes(ctx, mode); + mlir::IndexType indexType = mlir::IndexType::get(ctx); + mlir::IntegerAttr indexAttribute = + mlir::IntegerAttr::get(indexType, static_cast(id)); + self.create( + modeAttr, indexAttribute, mlir::Value{}, cubePipe, vectorPipe); + }) + .def("is_910_95", + [](AscendNPUIROpBuilder &self) -> bool { return self.is_910_95(); }) + .def("create_copy_buffer", + [](AscendNPUIROpBuilder &self, Value src, Value dst) { + self.create(mlir::TypeRange{}, src, dst); + }) + .def("create_copy_tensor", + [](AscendNPUIROpBuilder &self, Value src, Value dst) { + return self + .create(mlir::TypeRange{dst.getType()}, src, dst) + .getResult(0); + }) + .def("create_convert_layout", + [](AscendNPUIROpBuilder &self, Value src, Type memrefType) -> Value { + // src is a memref + // the layout is incorrect (temporarily) + auto *ctx = self.getBuilder().getContext(); + return self + .create( + memrefType, src, + hivm::DataLayoutAttr::get(ctx, hivm::DataLayout::ND), + hivm::DataLayoutAttr::get(ctx, hivm::DataLayout::ND)) + .getResult(); + }); +} diff --git a/third_party/ascend/backend/__init__.py b/third_party/ascend/backend/__init__.py index 0eec99724..ea66d3e62 100644 --- a/third_party/ascend/backend/__init__.py +++ b/third_party/ascend/backend/__init__.py @@ -1,2 +1,55 @@ -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import logging +from triton._C.libtriton.ascend import ir as ascend_ir + +from .testing import do_bench_npu + + +def _apply_ascend_patch(): + from triton.compiler.compiler import ASTSource + + if not getattr(ASTSource, "_ascend_patch_applied", False): + _original_make_ir = ASTSource.make_ir + + def _patched_make_ir(self, options, codegen_fns, module_map, context): + """ + Monkey Patch for Ascend: + Injects 'hacc.target' attribute into the module after generation. + """ + module = _original_make_ir(self, options, codegen_fns, module_map, context) + + if hasattr(options, "arch") and options.arch: + try: + builder = ascend_ir.ascendnpu_ir_builder(context, options.arch) + + target_attr_str = f'#hacc.target<"{options.arch}">' + module.set_attr("hacc.target", builder.parse_attr(target_attr_str)) + except Exception as e: + logging.warning(f"[Ascend Patch] Failed to set hacc.target: {e}") + + return module + + ASTSource.make_ir = _patched_make_ir + ASTSource._ascend_patch_applied = True + + +__all__ = ["do_bench_npu"] diff --git a/third_party/ascend/backend/backend_register.py b/third_party/ascend/backend/backend_register.py new file mode 100644 index 000000000..480f2a7fe --- /dev/null +++ b/third_party/ascend/backend/backend_register.py @@ -0,0 +1,321 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import os +from typing import Callable, Dict + + +class BackendStrategyRegistry: + + def __init__(self): + self.strategies: Dict[str:Dict[str, Callable]] = {} + + def register(self, category: str, method: str): + + def decorator(func: Callable): + if category not in self.strategies: + self.strategies[category] = {} + if method in self.strategies[category]: + raise ValueError(f"Strategy {name} already registered") + self.strategies[category][method] = func + return func + + return decorator + + def execute_func(self, category, method, *args, **kwargs): + if category not in self.strategies: + raise ValueError(f"Strategy {category} not registered") + if method not in self.strategies[category]: + raise ValueError(f"Strategy {method} not registered") + return self.strategies[category][method](*args, **kwargs) + + def list_categories(self): + return list(self.strategies.keys()) + + def list_methods(self, category): + if category not in self.strategies: + raise ValueError(f"Strategy {category} not registered") + return list(self.strategies[category].keys()) + + +class _LazyBackendStrategyRegister: + + def __init__(self): + self._instance = None + + def _get_instance(self): + if self._instance is None: + self._instance = BackendStrategyRegistry() + return self._instance + + def register(self, *args, **kwargs): + return self._get_instance().register(*args, **kwargs) + + def execute_func(self, *args, **kwargs): + return self._get_instance().execute_func(*args, **kwargs) + + +backend_strategy_registry = _LazyBackendStrategyRegister() + + +@backend_strategy_registry.register("mindspore", "version_hash") +def version_hash(): + import mindspore + return [str(mindspore.version)] + + +@backend_strategy_registry.register("torch_npu", "version_hash") +def version_hash(): + import torch + import torch_npu + return [torch.version.git_version, torch_npu.version.git_version] + + +@backend_strategy_registry.register("mindspore", "cxx_abi") +def get_mindspore_cxx_abi(): + return 0 + + +@backend_strategy_registry.register("torch_npu", "cxx_abi") +def get_torch_cxx_abi(): + import torch + return 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 + + +@backend_strategy_registry.register("mindspore", "type_convert") +def type_convert(): + import mindspore + import numpy as np + MINDSPORE_TO_NUMPY_DTYPE = { + mindspore.float32: np.float32, + mindspore.float64: np.float64, + mindspore.float16: np.float16, + mindspore.int8: np.int8, + mindspore.uint8: np.uint8, + mindspore.int16: np.int16, + mindspore.int32: np.int32, + mindspore.int64: np.int64, + mindspore.bool: np.bool_, + mindspore.complex64: np.complex64, + mindspore.complex128: np.complex128, + } + return MINDSPORE_TO_NUMPY_DTYPE + + +@backend_strategy_registry.register("torch_npu", "type_convert") +def type_convert(): + import torch + import numpy as np + TORCH_TO_NUMPY_DTYPE = { + torch.float32: np.float32, + torch.float64: np.float64, + torch.float16: np.float16, + torch.int8: np.int8, + torch.uint8: np.uint8, + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + torch.bool: np.bool_, + torch.complex64: np.complex64, + torch.complex128: np.complex128, + } + return TORCH_TO_NUMPY_DTYPE + + +@backend_strategy_registry.register("mindspore", "get_device_interface") +def get_device_interface(): + import mindspore + return mindspore + + +@backend_strategy_registry.register("torch_npu", "get_device_interface") +def get_device_interface(): + import torch + return torch.npu + + +@backend_strategy_registry.register("mindspore", "get_empty_tensor") +def 
get_empty_tensor(size): + import mindspore + return mindspore.mint.empty(size, dtype=mindspore.int32) + + +@backend_strategy_registry.register("torch_npu", "get_empty_tensor") +def get_empty_tensor(size): + import torch + return torch.empty(size, dtype=torch.int32, device='npu') + + +@backend_strategy_registry.register("mindspore", "get_tensor_params_shape") +def get_tensor_params_shape(args): + import mindspore + tensor_params = [arg for arg in args if isinstance(arg, mindspore.Tensor)] + tensor_params_shape = [] + for t in tensor_params: + tensor_params_shape.append([s for s in t.shape]) + return tensor_params_shape + + +@backend_strategy_registry.register("torch_npu", "get_tensor_params_shape") +def get_tensor_params_shape(args): + import torch + tensor_params = [arg for arg in args if isinstance(arg, torch.Tensor)] + tensor_params_shape = [] + for t in tensor_params: + tensor_params_shape.append([s for s in t.shape]) + return tensor_params_shape + + +@backend_strategy_registry.register("mindspore", "get_cc_cmd") +def get_cc_cmd(build_pch): + import mindspore + mindspore_path = os.path.dirname(os.path.realpath(mindspore.__file__)) + cc_cmd = [ + f"-I{os.path.join(mindspore_path, 'include/third_party')}", + f"-I{os.path.join(mindspore_path, 'include/third_party/robin_hood_hashing')}", + f"-I{os.path.join(mindspore_path, 'include/mindspore/core')}", + f"-I{os.path.join(mindspore_path, 'include/mindspore/core/include')}", + f"-I{os.path.join(mindspore_path, 'include/mindspore/ccsrc')}", + f"-I{os.path.join(mindspore_path, 'include/mindspore/ccsrc/include')}", + f"-I{os.path.join(mindspore_path, 'include/mindspore/ops')}", + f"-I{os.path.join(mindspore_path, 'include/mindspore/ops/include')}", + f"-D_GLIBCXX_USE_CXX11_ABI={get_mindspore_cxx_abi()}", + "-DENABLE_FAST_HASH_TABLE=1", + ] + if not build_pch: + cc_cmd += [ + f"-L{os.path.join(mindspore_path, 'lib')}", + f"-lmindspore_pynative_utils", + ] + return cc_cmd + + +@backend_strategy_registry.register("torch_npu", "get_cc_cmd") +def get_cc_cmd(build_pch): + import torch + import torch_npu + torch_path = os.path.dirname(os.path.realpath(torch.__file__)) + torch_npu_path = os.path.dirname(os.path.realpath(torch_npu.__file__)) + cc_cmd = [ + f"-I{os.path.join(torch_path, 'include')}", + f"-I{os.path.join(torch_npu_path, 'include')}", + f"-D_GLIBCXX_USE_CXX11_ABI={get_torch_cxx_abi()}", + ] + if not build_pch: + cc_cmd += [ + f"-L{os.path.join(torch_npu_path, 'lib')}", + f"-ltorch_npu", + ] + return cc_cmd + + +@backend_strategy_registry.register("mindspore", "get_current_device") +def get_current_device(): + import mindspore + return mindspore.get_current_device().device_id + + +@backend_strategy_registry.register("torch_npu", "get_current_device") +def get_current_device(): + import torch + import torch_npu + return torch.npu.current_device() + + +@backend_strategy_registry.register("mindspore", "set_current_device") +def set_current_device(device_id): + import mindspore + return mindspore.set_device("Ascend", device_id) + + +@backend_strategy_registry.register("torch_npu", "set_current_device") +def set_current_device(device_id): + import torch + import torch_npu + return torch.npu.set_device(device_id) + + +@backend_strategy_registry.register("mindspore", "get_current_stream") +def get_current_stream(device): + import mindspore + return mindspore.current_stream().id + + +@backend_strategy_registry.register("torch_npu", "get_current_stream") +def get_current_stream(device): + import torch + import torch_npu + from torch_npu._C import 
_npu_getCurrentRawStream + if device is None: + device = torch.npu.current_device() + return _npu_getCurrentRawStream(device) + + +@backend_strategy_registry.register("mindspore", "header_file") +def header_file(enable_taskqueue): + return f'''#include "include/utils/device_manager_conf.h" +#include "include/runtime/hardware_abstract/device_context/device_context_manager.h" +{'#include "include/pynative/utils/runtime/op_executor.h"' if {enable_taskqueue} else ''} +{'#include "include/runtime/pipeline/pipeline.h"' if {enable_taskqueue} else ''}''' + + +@backend_strategy_registry.register("torch_npu", "header_file") +def header_file(enable_taskqueue): + return f'''#include +#include +{'#include ' if {enable_taskqueue} else ''}''' + + +@backend_strategy_registry.register("mindspore", "allocate_memory") +def allocate_memory(size, stream): + return f"device_context->device_res_manager_->AllocateMemory({size}, reinterpret_cast({stream}));" + + +@backend_strategy_registry.register("torch_npu", "allocate_memory") +def allocate_memory(size, option): + return f"const_cast(at::empty({size}, {option}).storage().data());" + + +@backend_strategy_registry.register("torch_npu", "allocate_sync_block_lock") +def allocate_sync_block_lock(size, stream): + return f"const_cast(at_npu::native::allocate_workspace({size}, {stream}).storage().data());" + + +@backend_strategy_registry.register("mindspore", "pre_launch") +def pre_launch(): + return '''static auto device_context = mindspore::device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({mindspore::device::DeviceType::kAscend, mindspore::DeviceManagerConf::GetInstance()->device_id()}); + device_context->device_res_manager_->BindDeviceToCurrentThread(false);''' + + +@backend_strategy_registry.register("torch_npu", "pre_launch") +def pre_launch(): + return "" + + +@backend_strategy_registry.register("mindspore", "async_launch") +def async_launch(func): + return f'''mindspore::runtime::OpExecutor::DispatchLaunchTask({func}); + mindspore::runtime::Pipeline::Get().launch_stage()->Wait();''' + + +@backend_strategy_registry.register("torch_npu", "async_launch") +def async_launch(func): + return f'''at_npu::native::OpCommand cmd; + cmd.Name(name.c_str()).SetCustomHandler({func}).Run();''' diff --git a/third_party/ascend/backend/compiler.py b/third_party/ascend/backend/compiler.py index 098272ef0..033af4c4e 100644 --- a/third_party/ascend/backend/compiler.py +++ b/third_party/ascend/backend/compiler.py @@ -1,3 +1,23 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + import ctypes import functools import hashlib @@ -241,6 +261,8 @@ def _parse_linalg_metadata(linalg: str, metadata: dict): """ # --- Regular expressions and examples --- + DISABLE_AUTO_TILE_AND_BIND_SUBBLOCK_REGEX = r'hivm.disable_auto_tile_and_bind_subblock' + # Example: mix_mode = "aiv" -> aiv MIX_MODE_REGEX = r'mix_mode\s*=\s*"([^"]+)"' @@ -259,6 +281,8 @@ def _parse_linalg_metadata(linalg: str, metadata: dict): # Note: Compiled Kernel requires to estimate size of shared memory to occupy # Currently, NPU backend does not limit on shared memory metadata["shared"] = 1 + # Force disable auto tile and bind subblock if attribute is present in module + metadata["auto_tile_and_bind_subblock"] = not re.search(DISABLE_AUTO_TILE_AND_BIND_SUBBLOCK_REGEX, linalg) # the mix mode is also encoded into metadata['name'] for runtime to distinguish metadata["mix_mode"] = re.search(MIX_MODE_REGEX, linalg).group(1) metadata["parallel_mode"] = re.search(PARALLEL_MODE_REGEX, linalg).group(1) @@ -320,11 +344,19 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): bin_path = os.path.join(tmpdir, bin_file_with_ext) callback_path = os.path.join(tmpdir, "libkernel.so") _compile_option_list = get_common_bishengir_compile_options(metadata) + multibuffer = metadata["multibuffer"] if multibuffer is not None: _compile_option_list += [ f"--enable-auto-multi-buffer={multibuffer}", ] + + enable_ubuf_saving = metadata["enable_ubuf_saving"] + if enable_ubuf_saving is not None: + _compile_option_list += [ + f"--enable-ubuf-saving={enable_ubuf_saving}", + ] + enable_auto_bind_sub_block = metadata["enable_auto_bind_sub_block"] if enable_auto_bind_sub_block is not None: _compile_option_list += [ @@ -380,23 +412,67 @@ def linalg_to_bin_enable_npu_compile_910_95(linalg: str, metadata, opt): _compile_option_list += \ [f"--limit-auto-multi-buffer-of-local-buffer={auto_multi_buffer}"] + enable_mixed_cv = metadata["enable_mixed_cv"] + if enable_mixed_cv is not None: + _compile_option_list += \ + [f"--enable_mixed_cv={enable_mixed_cv}"] + + enable_cce_vf_auto_sync = metadata["enable_cce_vf_auto_sync"] + if enable_cce_vf_auto_sync is not None: + _compile_option_list += \ + [f"--apend-bisheng-options=-mllvm --cce-vf-auto-sync={enable_cce_vf_auto_sync}"] + + enable_cce_vf_remove_membar = metadata["enable_cce_vf_remove_membar"] + if enable_cce_vf_remove_membar is not None: + _compile_option_list += \ + [f"--apend-bisheng-options=-mllvm --cce-vf-remove-membar={enable_cce_vf_remove_membar}"] + + enable_drop_unit_dims = metadata["enable_drop_unit_dims"] + if enable_drop_unit_dims is not None: + _compile_option_list += \ + [f"--enable-drop-unit-dims={enable_drop_unit_dims}"] + + enable_auto_vectorize_v2 = metadata["enable_auto_vectorize_v2"] + if enable_auto_vectorize_v2 is not None: + _compile_option_list += \ + [f"--enable-auto-vectorize-v2={enable_auto_vectorize_v2}"] + + disable_auto_inject_block_sync = metadata["disable_auto_inject_block_sync"] + if disable_auto_inject_block_sync is not None: + _compile_option_list += \ + [f"--disable-auto-inject-block-sync={disable_auto_inject_block_sync}"] + if _is_auto_map_parallel_blocks_enabled(): _compile_option_list += ["--enable-auto-blockify-loop"] - npu_compiler_path = _get_npucompiler_path() + 
npu_compiler_path, env = _get_npucompiler_path() if npu_compiler_path.endswith("bishengir-compile"): _compile_option_list += [ "--enable-hfusion-compile=true", "--enable-triton-kernel-compile=true", ] + bisheng_options = metadata["bisheng_options"] + if bisheng_options is not None: + _compile_option_list += [f"--append-bisheng-options={bisheng_options}"] + mix_mode = opt.mix_mode + if mix_mode in ["aic"]: + _compile_option_list += ["--disable-hfusion-vectorize=true"] cmd_list = ([npu_compiler_path, ttadapter_path] + _compile_option_list + ["-o", bin_file]) - - ret = subprocess.run(cmd_list, capture_output=True, check=True) + # TODO both bishengir-compile and triton-compile use passing attr by module + auto_tile_and_bind_subblock = metadata["auto_tile_and_bind_subblock"] + if auto_tile_and_bind_subblock is False: + cmd_list += ["--enable-auto-bind-sub-block=false"] + vf_merge_level = metadata["vf_merge_level"] + if vf_merge_level: + cmd_list += [f"--enable-vf-merge-level={vf_merge_level}"] + + ret = subprocess.run(cmd_list, env=env, capture_output=True, check=True) match = re.search(r'UB\s+size\s*=\s*(\d+)\s*bits', ret.stdout.decode('utf-8')) if match: # get the ub bits of triton kernel from bisheng for inductor autotune using metadata["required_ub_bits"] = int(match.group(1)) if Path(callback_path).is_file(): lib = ctypes.CDLL(callback_path) + __get_metadata_attr_by_callback(lib, "_infer_task_type_function", metadata, "bs_task_type") __get_metadata_attr_by_callback(lib, "_infer_workspace_shape_function", metadata, "workspace_size") __get_metadata_attr_by_callback(lib, "_infer_sync_block_lock_num_function", metadata, "lock_num") __get_metadata_attr_by_callback(lib, "_infer_sync_block_lock_init_function", metadata, "lock_init_val") @@ -423,11 +499,19 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): _compile_option_list = [ f"--target={NPUUtils().get_arch()}", ] + multibuffer = metadata["multibuffer"] if multibuffer is not None: _compile_option_list += [ f"--enable-auto-multi-buffer={multibuffer}", ] + + enable_ubuf_saving = metadata["enable_ubuf_saving"] + if enable_ubuf_saving is not None: + _compile_option_list += [ + f"--enable-ubuf-saving={enable_ubuf_saving}", + ] + enable_auto_bind_sub_block = metadata["enable_auto_bind_sub_block"] if enable_auto_bind_sub_block is not None: _compile_option_list += [ @@ -456,6 +540,16 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): _compile_option_list += \ [f"--enable-hivm-unit-flag-sync={unit_flag}"] + enable_drop_unit_dims = metadata["enable_drop_unit_dims"] + if enable_drop_unit_dims is not None: + _compile_option_list += \ + [f"--enable-drop-unit-dims={enable_drop_unit_dims}"] + + enable_auto_vectorize_v2 = metadata["enable_auto_vectorize_v2"] + if enable_auto_vectorize_v2 is not None: + _compile_option_list += \ + [f"--enable-auto-vectorize-v2={enable_auto_vectorize_v2}"] + inject_barrier_all = metadata["inject_barrier_all"] if inject_barrier_all is not None: _compile_option_list += \ @@ -498,7 +592,7 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): if _is_auto_map_parallel_blocks_enabled(): _compile_option_list += ["--enable-auto-blockify-loop"] - npu_compiler_path = _get_npucompiler_path() + npu_compiler_path, env = _get_npucompiler_path() if npu_compiler_path.endswith("bishengir-compile"): _compile_option_list += [ "--enable-hfusion-compile=true", @@ -506,13 +600,17 @@ def linalg_to_bin_enable_npu_compile_A2_A3(linalg: str, metadata, opt): 
"--enable-triton-kernel-compile=true", ] cmd_list = ([npu_compiler_path, ttadapter_path] + _compile_option_list + ["-o", bin_file]) - ret = subprocess.run(cmd_list, capture_output=True, check=True) + auto_tile_and_bind_subblock = metadata["auto_tile_and_bind_subblock"] + if auto_tile_and_bind_subblock is False: + cmd_list += ["--enable-auto-bind-sub-block=false"] + ret = subprocess.run(cmd_list, env=env, capture_output=True, check=True) match = re.search(r'UB\s+size\s*=\s*(\d+)\s*bits', ret.stdout.decode('utf-8')) if match: # get the ub bits of triton kernel from bisheng for inductor autotune using metadata["required_ub_bits"] = int(match.group(1)) if Path(callback_path).is_file(): lib = ctypes.CDLL(callback_path) + __get_metadata_attr_by_callback(lib, "_infer_task_type_function", metadata, "bs_task_type") __get_metadata_attr_by_callback(lib, "_infer_workspace_shape_function", metadata, "workspace_size") __get_metadata_attr_by_callback(lib, "_infer_sync_block_lock_num_function", metadata, "lock_num") __get_metadata_attr_by_callback(lib, "_infer_sync_block_lock_init_function", metadata, "lock_init_val") @@ -526,6 +624,7 @@ class NPUOptions: sanitize_overflow: bool = True llvm_version: int = 15 kernel_name: str = "triton_" + arch: str = "" cluster_dims: tuple = (1, 1, 1) num_warps: int = 4 @@ -538,23 +637,35 @@ class NPUOptions: reg_inc_consumer: int = 0 compile_on_910_95: bool = is_compile_on_910_95 - enable_linearize: bool = False + optimize_dynamic_offset: bool = False + enable_mask_fallback_conversion: bool = False enable_warp_specialization: bool = False enable_nd2nz_on_vector: bool = False enable_persistent: bool = False optimize_epilogue: bool = False enable_fp_fusion: bool = True allow_fp8e4nv: bool = False + auto_tile_and_bind_subblock: bool = True + vf_merge_level: int = 0 + supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15", "fp8e4nv", "fp8e4b8", "fp8e5b16") + deprecated_fp8_dtypes: Tuple[str] = () + vf_merge_level: int = 1 allowed_dot_input_precisions: Tuple[str] = ("ieee", "hf32") - max_num_imprecise_acc_default: bool = None + max_num_imprecise_acc_default: int = 0 extern_libs: dict = None + bisheng_options: str = None multibuffer: bool = not is_compile_on_910_95 + enable_ubuf_saving: bool = None enable_auto_bind_sub_block: bool = not is_compile_on_910_95 - enable_select_analysis: bool = is_compile_on_910_95 + enable_select_analysis: bool = True enable_hivm_auto_cv_balance: bool = None sync_solver: bool = None unit_flag: bool = None + enable_cce_vf_auto_sync: bool = None + enable_cce_vf_remove_membar: bool = None + enable_drop_unit_dims: bool = None + enable_auto_vectorize_v2: bool = None inject_barrier_all: bool = None inject_block_all: bool = None limit_auto_multi_buffer_only_for_local_buffer: bool = None @@ -563,14 +674,22 @@ class NPUOptions: tile_mix_vector_loop: int = None tile_mix_cube_loop: int = None disable_auto_inject_block_sync: bool = None + enable_mixed_cv: bool = None stream: int = None parallel_mode: str = "simd" force_simt_only: bool = False force_simt_template: bool = False + # only take effect on the simt-only & simd-simt-mix scenarios + shared_mem_dynamic_size: int = 221184 + # enable_bishengir_simt_optimization is passed as + # -enable-bishengir-simt-optimization flag to bishengir-compile. 
+ enable_bishengir_simt_optimization: int = 000 # compile_mode: "simd" (default), "unstructured_in_simt", "simt_only" # When compile_mode is provided, it automatically sets other fields compile_mode: str = "simd" + mix_mode: str = "" + simt_stack_limit: int = None def __post_init__(self): # Parse compile_mode and set related fields @@ -582,6 +701,7 @@ def __post_init__(self): elif self.compile_mode == "simt_only": object.__setattr__(self, "force_simt_only", True) object.__setattr__(self, "parallel_mode", "simt") + object.__setattr__(self, "shared_mem_dynamic_size", 122880) def hash(self): key = "_".join([f"{name}-{val}" for name, val in self.__dict__.items()]) @@ -604,7 +724,7 @@ class CPUOptions: optimize_epilogue: bool = False enable_fp_fusion: bool = True allow_fp8e4nv: bool = False - max_num_imprecise_acc_default: bool = None + max_num_imprecise_acc_default: int = 0 extern_libs: dict = None def hash(self): @@ -639,10 +759,18 @@ def ttir_to_npubin(mod, metadata, opt): _compile_option_list += ["--pure-simt"] _compile_option_list += [f"--num-warps={opt.num_warps}"] _compile_option_list += [f"--threads-per-warp={opt.warp_size}"] - - npu_compiler_path = _get_npucompiler_path() + if opt.enable_bishengir_simt_optimization != 000: + _compile_option_list += [ + f"--enable-bishengir-simt-optimization={opt.enable_bishengir_simt_optimization}" + ] + if opt.simt_stack_limit: + _compile_option_list += [f"--simt-stack-limit={opt.simt_stack_limit}"] + if opt.shared_mem_dynamic_size: + _compile_option_list += [f"--shared-mem-dynamic-size={opt.shared_mem_dynamic_size}"] + + npu_compiler_path, env = _get_npucompiler_path() cmd_list = ([npu_compiler_path, src_path] + _compile_option_list + ["-o", bin_file]) - ret = subprocess.run(cmd_list, capture_output=True, check=True) + ret = subprocess.run(cmd_list, env=env, capture_output=True, check=True) return Path(bin_path).read_bytes() @@ -663,6 +791,7 @@ def parse_options(self, opts) -> Any: # TODO: get available targets when building options? if self.target.backend == "npu": args = {k: opts[k] for k in NPUOptions.__dataclass_fields__.keys() if k in opts} + args.setdefault("arch", self.target.arch) options = NPUOptions(**args) else: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} @@ -692,11 +821,13 @@ def pack_metadata(self, metadata): def get_codegen_implementation(self): # Note: a dict of functions is required to generate vendor-specific code piecies # e.g. convert custom types like fp8e4b15 + from triton.backends.ascend import _apply_ascend_patch + _apply_ascend_patch() codegen_fns = {"min_dot_size": min_dot_size(self.target)} return codegen_fns def load_dialects(self, ctx): - pass + ascend.load_dialects(ctx) def get_attrs_descriptor(self, params, args): return AscendAttrsDescriptor(params, args) diff --git a/third_party/ascend/backend/cpu_driver.py b/third_party/ascend/backend/cpu_driver.py index 2e00ad04d..783ac7620 100644 --- a/third_party/ascend/backend/cpu_driver.py +++ b/third_party/ascend/backend/cpu_driver.py @@ -1,3 +1,23 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + from triton.runtime.cache import get_cache_manager, get_dump_manager from pathlib import Path import tempfile diff --git a/third_party/ascend/backend/driver.py b/third_party/ascend/backend/driver.py index 9488be2bb..c1169d84a 100644 --- a/third_party/ascend/backend/driver.py +++ b/third_party/ascend/backend/driver.py @@ -1,3 +1,23 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ from pathlib import Path import tempfile import os @@ -11,17 +31,10 @@ from triton.runtime.cache import get_cache_manager, get_dump_manager from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget -from triton.backends.ascend.utils import ( - _precompile_npu_hash, - _precompile_npu_ext, - _build_npu_ext, - _check_cxx11_abi, - convert_sigtype_to_int, - _is_auto_map_parallel_blocks_enabled, - get_ascend_arch_from_env, - is_ffts_supported, - force_disable_ffts, -) +from triton.backends.ascend.utils import (_precompile_npu_hash, _precompile_npu_ext, _build_npu_ext, _check_cxx11_abi, + convert_sigtype_to_int, _is_auto_map_parallel_blocks_enabled, + get_ascend_arch_from_env, is_ffts_supported, force_disable_ffts, + get_backend_func) class NPUUtils(object): @@ -126,11 +139,7 @@ def __call__(self, *args, **kwargs): print("[INFO]: skip running kernel") print(f"[INFO]: The compiled kernel cache is in {cache_manager.cache_dir}") if self.enable_msprof_register_tensor: - import torch - tensor_params = [arg for arg in args if isinstance(arg, torch.Tensor)] - tensor_params_shape = [] - for t in tensor_params: - tensor_params_shape.append([s for s in t.shape]) + tensor_params_shape = get_backend_func("get_tensor_params_shape", *args) # args[5] must be the packed metadata. # Check the launch wrapper in which PyArg_ParseTuple specifies the ordering of args args[5]['tensor_params_shape'] = tensor_params_shape @@ -150,16 +159,6 @@ def __init__(self): # flagtree backend specialization from triton.backends.ascend import spec self.spec = spec - from triton.language.core import spec_core_func - spec_core_func(spec) - from triton.language.semantic import spec_semantic_func - spec_semantic_func(spec) - from triton.language.standard import spec_standard_func - spec_standard_func(spec) - from triton.language.math import spec_math_func - spec_math_func(spec) - from triton.testing import spec_testing_func - spec_testing_func(spec) super().__init__() @classmethod @@ -194,17 +193,13 @@ def get_current_device(self): """ Get current device """ - import torch - import torch_npu - return torch.npu.current_device() + return get_backend_func("get_current_device") def set_current_device(self, device): """ Set current device as the given device """ - import torch - import torch_npu - return torch.npu.set_device(device) + return get_backend_func("set_current_device", device) def get_current_stream(self, device: Optional[int] = None) -> int: """ @@ -212,25 +207,31 @@ def get_current_stream(self, device: Optional[int] = None) -> int: """ # According to torch_npu, the content of a torch.npu.Stream is essentilly an rtStream_t # TODO: use CANN API instead of torchnpu - import torch - import torch_npu - from torch_npu._C import _npu_getCurrentRawStream - if device is None: - device = self.get_current_device() - return _npu_getCurrentRawStream(device) + return get_backend_func("get_current_stream", device) def get_benchmarker(self): from triton.testing import do_bench return do_bench def get_device_interface(self): - import torch - return torch.npu + return get_backend_func("get_device_interface") def get_empty_cache_for_benchmark(self): - import torch cache_size = 192 * 1024 * 1024 - return torch.empty(cache_size // 4, dtype=torch.int, device='npu') + return get_backend_func("get_empty_tensor", cache_size // 4) + + +# fixed the issue of corrupted gch header files in multi-threaded scenarios. 
+def _precompile_npu_ext_with_lock(header_path): + import fcntl + src_path = os.path.dirname(header_path) + lock_path = os.path.join(src_path, "precompiled.lock") + with open(lock_path, "a+") as f: + try: + fcntl.flock(f, fcntl.LOCK_EX) + _precompile_npu_ext(header_path) + finally: + fcntl.flock(f, fcntl.LOCK_UN) def make_npu_launcher_stub(header_src, wrapper_src, debug=False): @@ -244,7 +245,7 @@ def make_npu_launcher_stub(header_src, wrapper_src, debug=False): # if precompile header file and its gch file not exist, do precompile if header_path is None and gch_path is None: header_path = cache.put(header_src, "precompiled.h", binary=False) - _precompile_npu_ext(header_path) + _precompile_npu_ext_with_lock(header_path) # try to get cached file so_cache_key = hashlib.sha256(wrapper_src.encode("utf-8")).hexdigest() @@ -389,11 +390,9 @@ def generate_npu_header_src(): #include #include #include -#include "experiment/runtime/runtime/rt.h" -#include +#include "runtime/runtime/rt.h" #include -#include -{'#include ' if enable_taskqueue else ''} +{get_backend_func("header_file", enable_taskqueue)} #endif """ @@ -408,6 +407,7 @@ def generate_npu_wrapper_src(constants, signature, metadata): if hasattr(metadata, 'lock_init_value') else 0 lock_num = int(metadata.lock_num) \ if hasattr(metadata, 'lock_num') else -1 + bs_task_type = metadata.bs_task_type if hasattr(metadata, 'bs_task_type') else 0 mix_mode = metadata.mix_mode compile_on_910_95 = metadata.compile_on_910_95 parallel_mode = metadata.parallel_mode @@ -422,6 +422,9 @@ def _ty_to_cpp(ty): "i16": "int16_t", "i32": "int32_t", "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", "u32": "uint32_t", "u64": "uint64_t", "fp16": "float", @@ -436,8 +439,13 @@ def _extracted_ty(ty): return "PyObject*" return { 'i1': 'int32_t', + 'i8': 'int8_t', + 'i16': 'int16_t', 'i32': 'int32_t', 'i64': 'int64_t', + 'u1': 'uint32_t', + 'u8': 'uint8_t', + 'u16': 'uint16_t', 'u32': 'uint32_t', 'u64': 'uint64_t', 'fp16': 'float', @@ -453,12 +461,34 @@ def _format_of(ty): "float": "f", "double": "d", "long": "l", - "uint32_t": "I", + "int8_t": "b", + "int16_t": "h", "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", "uint64_t": "K", - "int64_t": "L", }[ty] + def _format_of_msprof_task_type_ratio(bs_task_type, mix_mode): + # Default fallback based on mix_mode + default_task_type = "MSPROF_GE_TASK_TYPE_AIV" if mix_mode == "aiv" else "MSPROF_GE_TASK_TYPE_AI_CORE" + + if not bs_task_type: + return default_task_type, 0 + + task_type_num, mix_block_dim_ratio = divmod(int(bs_task_type), 10) + task_type_map = { + 1: "MSPROF_GE_TASK_TYPE_AIV", + 2: "MSPROF_GE_TASK_TYPE_AI_CORE", + 3: "MSPROF_GE_TASK_TYPE_MIX_AIC", + 4: "MSPROF_GE_TASK_TYPE_MIX_AIV", + } + + task_type = task_type_map.get(task_type_num, default_task_type) + return task_type, mix_block_dim_ratio + arg_decls = ', '.join(f"{_ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) """ args: @@ -482,7 +512,8 @@ def _format_of(ty): enable_auto_map_parallel_blocks = _is_auto_map_parallel_blocks_enabled() npu_utils = NPUUtils() num_physical_blocks = npu_utils.get_aivector_core_num() if mix_mode == "aiv" else npu_utils.get_aicore_num() - task_type = "MSPROF_GE_TASK_TYPE_AIV" if mix_mode == "aiv" else "MSPROF_GE_TASK_TYPE_AI_CORE" + task_type, mix_block_dim_ratio = _format_of_msprof_task_type_ratio(bs_task_type, mix_mode) + is_mix_task_type = "true" if ("MIX" in task_type) else "false" LINE_CHANGE_CHAR = chr(10) # it is \n alloc_success_code = 'return 1;' 
sync_lock_fail_code = 'fprintf(stderr, "Error: syncBlockLock allocation failed\\n"); return;' @@ -601,9 +632,30 @@ def _format_of(ty): nodeBasicInfo.data.nodeBasicInfo.opName = opNameHashID; nodeBasicInfo.data.nodeBasicInfo.opType = opNameHashID; nodeBasicInfo.data.nodeBasicInfo.taskType = {task_type}; - nodeBasicInfo.data.nodeBasicInfo.blockDim = blockNum; + nodeBasicInfo.data.nodeBasicInfo.blockDim = nodeBasicBlockDim; MsprofReportCompactInfo(0, static_cast(&nodeBasicInfo), sizeof(MsprofCompactInfo)); + // 'mix' kernel need to report the ctxID + if ({is_mix_task_type} > 0) {{ + MsprofAdditionalInfo info; + info.level = MSPROF_REPORT_NODE_LEVEL; + info.type = MSPROF_REPORT_NODE_CONTEXT_ID_INFO_TYPE; + info.threadId = threadId; + info.timeStamp = endTime; + MsprofContextIdInfo ctxId; + ctxId.opName = opNameHashID; + ctxId.ctxIdNum = 1; + for (uint32_t i = 0; i < ctxId.ctxIdNum; i++) {{ + ctxId.ctxIds[i] = i; + }} + size_t copyLen = sizeof(MsprofContextIdInfo); + if (copyLen > MSPROF_ADDTIONAL_INFO_DATA_LENGTH) {{ + copyLen = MSPROF_ADDTIONAL_INFO_DATA_LENGTH; + }} + memcpy(info.data, &ctxId, copyLen); + MsprofReportAdditionalInfo(false, static_cast(&info), sizeof(MsprofAdditionalInfo)); + }} + // Report tensor info int max_tensors_num = tensorShapes.size() < MSPROF_GE_TENSOR_DATA_NUM ? tensorShapes.size() : MSPROF_GE_TENSOR_DATA_NUM; MsprofAdditionalInfo tensorInfo; @@ -655,12 +707,12 @@ def _format_of(ty): ret = rtKernelLaunch(func, blockNum, static_cast(&args), sizeof(args), NULL, stream); """ if compile_on_910_95 and enable_simt: - cpp_kernel_launch = """ - rtArgsEx_t argsInfo = {}; + cpp_kernel_launch = f""" + rtArgsEx_t argsInfo = {{}}; argsInfo.args = static_cast(&args); argsInfo.argsSize = sizeof(args); - rtTaskCfgInfo_t cfgInfo = {}; - cfgInfo.localMemorySize = 216 * 1024; + rtTaskCfgInfo_t cfgInfo = {{}}; + cfgInfo.localMemorySize = {metadata.shared_mem_dynamic_size}; ret = rtKernelLaunchWithFlagV2(func, blockNum, &argsInfo, NULL, stream, 0, &cfgInfo); """ @@ -688,6 +740,13 @@ def _format_of(ty): // base_ptr offset shape and stride are not used, arbitrarily set for now std::string name = ""; name.append(kernelName); + void *workspace_addr_ptr = NULL; + uint32_t blockNum4Workspace = gridX * gridY * gridZ; + {f''' + uint64_t totalWorkSpaceSize = {workspace_size} * blockNum4Workspace; + auto optionsWorkspace = at::TensorOptions().device(at::kPrivateUse1).dtype(at::kByte); + workspace_addr_ptr = {get_backend_func("allocate_memory", "totalWorkSpaceSize", "optionsWorkspace")} + ''' if workspace_size > 0 else ''} {'auto launch_call = [=]() -> rtError_t' if enable_taskqueue else ''} {{ uint32_t blockNum = gridX * gridY * gridZ; @@ -698,20 +757,22 @@ def _format_of(ty): warned = true; }} #endif - + {get_backend_func("pre_launch")} {'blockNum = std::min(blockNum, (uint32_t)' + str(num_physical_blocks) + ');' if enable_auto_map_parallel_blocks else ''} + // set mixBlockNumRation for nodeBasicBlockDim for msprof report + uint32_t mixBlockNumRation = {mix_block_dim_ratio}; + uint32_t nodeBasicBlockDim = (mixBlockNumRation << 16) + blockNum; + {'cce::internal::DebugTunnelData *DTData = cce::internal::DebugTunnel::Open(blockNum);' if enable_device_print else ''} rtError_t ret; {'void *ffts_addr = NULL; uint32_t ffts_len; ret = rtGetC2cCtrlAddr((uint64_t*)&ffts_addr, &ffts_len);' if target_support_ffts else ''} {'if (ret != RT_ERROR_NONE) return ret;' if (target_support_ffts and enable_taskqueue) else 'if (ret != RT_ERROR_NONE) return;' if (target_support_ffts and (not enable_taskqueue)) else ''} 
// stub argument for workspace void *syncBlockLock_ptr = NULL; - void *workspace_addr_ptr = NULL; uint16_t ModuleId = 0; {f''' uint64_t syncBlockLockSize = {lock_num} * sizeof(int64_t); - at::Tensor syncBlockLock_tensor = at_npu::native::allocate_workspace(syncBlockLockSize, stream); - syncBlockLock_ptr = const_cast(syncBlockLock_tensor.storage().data()); + syncBlockLock_ptr = {get_backend_func("allocate_sync_block_lock", "syncBlockLockSize", "stream")} if (!syncBlockLock_ptr) {{ {alloc_success_code if enable_taskqueue else sync_lock_fail_code} }} @@ -725,14 +786,6 @@ def _format_of(ty): return {'ret' if enable_taskqueue else ''}; }} ''' if lock_num > 0 else ''} - {f''' - uint64_t totalWorkSpaceSize = {workspace_size} * blockNum; - at::Tensor workspace_tensor = at_npu::native::allocate_workspace(totalWorkSpaceSize, stream); - workspace_addr_ptr = const_cast(workspace_tensor.storage().data()); - if (!workspace_addr_ptr) {{ - {alloc_success_code if enable_taskqueue else workspace_fail_code} - }} - ''' if workspace_size > 0 else ''} {'if (ret != RT_ERROR_NONE) return ret;' if (workspace_size > 0 and enable_taskqueue) else 'if (ret != RT_ERROR_NONE) return;' if (workspace_size > 0 and not enable_taskqueue) else ''} struct __attribute__((packed)) {{ {'void* ffts_addr __attribute__((aligned(8)));' if target_support_ffts else ''} @@ -745,7 +798,9 @@ def _format_of(ty): {'static_cast(ffts_addr),' if target_support_ffts else ''} {('static_cast(syncBlockLock_ptr),' if lock_num > 0 else 'nullptr,') if not metadata.force_simt_only else ''} {('static_cast(workspace_addr_ptr),' if workspace_size > 0 else 'nullptr,') if not metadata.force_simt_only else ''} - {(', '.join(f'static_cast<{_ty_to_cpp(ty)}>(arg{i})' for i, ty in signature.items() if i not in constants) + ',') if len(signature) > 0 else ''} + {(lambda _rt: (', '.join(_rt) + ',') if _rt else '')( + [f'static_cast<{_ty_to_cpp(ty)}>(arg{i})' for i, ty in signature.items() if i not in constants] + )} {', '.join(f'static_cast<{_ty_to_cpp(ty)}>(grid{mark})' for mark, ty in grid_info.items())} {', static_cast(DTData)' if enable_device_print else ''} }}; @@ -756,7 +811,7 @@ def _format_of(ty): {cpp_msprof_call_after_launch} {'return ret;' if enable_taskqueue else 'ret = rtStreamSynchronize(stream);'} }}; - {'at_npu::native::OpCommand cmd; cmd.Name(name.c_str()).SetCustomHandler(launch_call).Run();' if enable_taskqueue else ''} + {f'''{get_backend_func("async_launch", "launch_call") if enable_taskqueue else ''}'''} return; }} diff --git a/third_party/ascend/backend/npu_utils.cpp b/third_party/ascend/backend/npu_utils.cpp index 530300ddb..28c4ec135 100644 --- a/third_party/ascend/backend/npu_utils.cpp +++ b/third_party/ascend/backend/npu_utils.cpp @@ -1,3 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + #define PY_SSIZE_T_CLEAN #include @@ -9,7 +31,7 @@ #include #include -#include "experiment/runtime/runtime/rt.h" +#include "runtime/runtime/rt.h" // Use map to differentiate same name functions from different binary static std::unordered_map registered_names; diff --git a/third_party/ascend/examples/pytest_ut/conftest.py b/third_party/ascend/backend/runtime/__init__.py similarity index 78% rename from third_party/ascend/examples/pytest_ut/conftest.py rename to third_party/ascend/backend/runtime/__init__.py index 633db8b1f..4609a317d 100644 --- a/third_party/ascend/examples/pytest_ut/conftest.py +++ b/third_party/ascend/backend/runtime/__init__.py @@ -18,16 +18,16 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -import pytest -import torch +def _patch_autotune(): + try: + import triton + except ImportError: + return -@pytest.fixture(scope="session", autouse=True) -def assign_npu(worker_id): - npu_count = torch.npu.device_count() - if worker_id == "master": - npu_id = 0 - else: - idx = int(worker_id.replace("gw", "")) - npu_id = idx % npu_count - torch.npu.set_device(npu_id) + from .autotuner import autotune + + triton.autotune = autotune + + +_patch_autotune() diff --git a/third_party/ascend/backend/spec/triton/runtime/autoparser.py b/third_party/ascend/backend/runtime/autoparser.py similarity index 53% rename from third_party/ascend/backend/spec/triton/runtime/autoparser.py rename to third_party/ascend/backend/runtime/autoparser.py index 5b1e7014a..2176d71e5 100644 --- a/third_party/ascend/backend/spec/triton/runtime/autoparser.py +++ b/third_party/ascend/backend/runtime/autoparser.py @@ -1,3 +1,23 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
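+# This module implements AST-based parsers that recover tuning-related axis
+# information from a Triton kernel's source: a shared base class (AxesKeyParser)
+# maps variables back to the axes declared in `keys`, and the concrete parsers
+# extract split axes (SplitAxesParser), tiling axes (TilingAxesParser),
+# reduction axes (ReductionAxesParser), low-dimensional axes (LowDimsAxesParser)
+# and the number of pointer-like parameters (PtrNumsParser).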
+ import ast from typing import Dict, List, Union @@ -43,7 +63,89 @@ def contains_target_var(self, node, var): return False -class SplitAxesParser(AutoParser): +class AxesKeyParser(AutoParser): + """ + A parser for extracting axis information from a given function's AST. + This class is designed to handle specific patterns in the function's code to + determine the axis associated with a given variable. It is particularly useful + for parsing triton DSL kernel code and identifying axis information. + It recursively processes assignment nodes and lessthan nodes to obtain the axes + corresponding to the specified var in the given function. + """ + + def __init__(self, func_ast: ast.AST, keys: Dict[str, str]): + super().__init__(func_ast) + self.keys = keys + self.checked_vars = list() + + def get_axis(self, var: str, node=None): + """ + Traverse the AST using the provided variable name and mask-based less-than + operations to obtain the corresponding axis name. + + :param var: the variable name to get the corresponding axis. + :type var: str + """ + if var in self.checked_vars: + return None + axis = None + if not node: + node = self.func_ast + for child_node in ast.walk(node): + # handle compare node + if isinstance(child_node, ast.Compare): + axis = self.handle_lt_node(var, child_node) + elif isinstance(child_node, ast.Assign): + axis = self.handle_assign_node(var, child_node) + if axis is not None: + return axis + self.checked_vars.append(var) + return None + + def handle_assign_node(self, var, node): + if not isinstance(node, ast.Assign) or not isinstance(node.targets, list): + return None + if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name): + return None + + target = node.targets[0].id + if target in self.checked_vars: + return None + # Prevent cyclic assignment. + if var == target or not self.contains_target_var(node.value, var): + return None + + axis = self.get_axis(var, node.value) + if axis: + return axis + + axis = self.get_axis(target) + return axis + + def handle_lt_node(self, var, node): + if not isinstance(node, ast.Compare) or not isinstance(node.ops, list): + return None + if len(node.ops) != 1 or not isinstance(node.ops[0], ast.Lt): + return None + if not isinstance(node.comparators, list) or len(node.comparators) != 1: + return None + if not isinstance(node.left, ast.Name) or var != node.left.id: + return None + + comparator = node.comparators[0] + if not isinstance(comparator, ast.Name) and \ + not (isinstance(comparator, ast.Call) and \ + isinstance(comparator.func, ast.Name) and \ + comparator.func.id == 'min'): + return None + + for k, v in self.keys.items(): + if self.contains_target_var(comparator, v): + return k + return None + + +class SplitAxesParser(AxesKeyParser): """ Extracts the split axis parameters from triton kernel code. The parsing is based on the `tl.program_id` statement. This class identifies potential split axes by analyzing the usage @@ -51,7 +153,7 @@ class SplitAxesParser(AutoParser): variables(currently supporting scenarios where multiplication is either direct or indirect via intermediate variables). It then filters these candidates based on a list of candidate parameters (parameters not provided by the user). After that, it confirms the split axis corresponding to - the current parameter using mask comparison and the `key` passed in `autotune`. + the current parameter using mask comparison and the `keys` passed in `autotune`. Note: 1. Split axis parameters must be multiplied with `tl.program_id`. 
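To make the parsing rules described in these docstrings concrete, here is a minimal, hypothetical kernel of the shape the parsers are written to recognize; all names (`row_sum_kernel`, `XBLOCK`, `XBLOCK_SUB`, `R_BLOCK`, `x_numel`, `r_numel`) are illustrative and not taken from this patch:

    import triton
    import triton.language as tl

    @triton.jit
    def row_sum_kernel(in_ptr, out_ptr, x_numel, r_numel,
                       XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr, R_BLOCK: tl.constexpr):
        x_base = tl.program_id(0) * XBLOCK          # XBLOCK multiplies tl.program_id -> split-axis parameter
        for x_off in range(0, XBLOCK, XBLOCK_SUB):  # loop step XBLOCK_SUB -> tiling-axis parameter
            x_rows = x_base + x_off + tl.arange(0, XBLOCK_SUB)
            r_cols = tl.arange(0, R_BLOCK)[None, :]  # expanded in the lowest dimension -> low-dim axis
            x_mask = x_rows[:, None] < x_numel       # mask comparison ties this axis to `x_numel`
            r_mask = r_cols < r_numel                # sliced comparison ties dimension 1 to `r_numel`
            vals = tl.load(in_ptr + x_rows[:, None] * r_numel + r_cols, mask=x_mask & r_mask, other=0.0)
            acc = tl.sum(vals, axis=1)               # axis=1 -> reduction axis
            tl.store(out_ptr + x_rows, acc, mask=x_rows < x_numel)

If `x_numel` and `r_numel` appear in `keys` while `XBLOCK`, `XBLOCK_SUB` and `R_BLOCK` are left unset at the call site, the parsers should report `XBLOCK` as the split parameter and `XBLOCK_SUB` as the tiling parameter of the `x_numel` axis, and classify the `r_numel` axis as both a reduction axis and a low-dimensional axis.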
@@ -61,23 +163,21 @@ class SplitAxesParser(AutoParser): only those parameters that can be dynamically adjusted through the autotune process are considered. """ - def __init__(self, func_ast: ast.AST, key, candidates_params: List[str]): + def __init__(self, func_ast: ast.AST, keys: Dict[str, str], candidates_params: List[str]): """ :param func_ast: Abstract syntax tree of the triton kernel function :type func_ast: ast.AST - :param key: a dict of axis name: argument name, used to confirm the split axis corresponding to + :param keys: a dict of axis name: argument name, used to confirm the split axis corresponding to the split axis parameters. - :type key: Dict[str, str] - :param candidatas_params: a list of parameters names that were not provided by the user when calling + :type keys: Dict[str, str] + :param candidates_params: a list of parameters names that were not provided by the user when calling triton kernel function. The parser will only consider these parameters as potential split axis parameters. :type candidates_params: List[str] """ - super().__init__(func_ast) + super().__init__(func_ast, keys) self.split_axes = dict() - self.key = key self.program_id_vars = list() - self.checked_vars = list() self.candidates_params = candidates_params def parse(self) -> Dict[str, str]: @@ -115,50 +215,20 @@ def visit_BinOp(self, node): if split_axes_val in self.candidates_params and \ split_axes_val not in self.split_axes.values(): - split_axes_key = self.get_split_axes_key(split_axes_val) + split_axes_key = self.get_axis(split_axes_val) if split_axes_key: self.split_axes[split_axes_key] = split_axes_val self.generic_visit(node) - def get_split_axes_key(self, var): - for node in ast.walk(self.func_ast): - if isinstance(node, ast.Compare): - if not isinstance(node.left, ast.Name) or \ - not isinstance(node.comparators[0], ast.Name): - continue - if var == node.left.id: - compared_var = node.comparators[0].id - elif var == node.comparators[0].id: - compared_var = node.left.id - else: - continue - for k, v in self.key.items(): - if v == compared_var: - return k - if isinstance(node, ast.Assign): - if len(node.targets) == 1 and \ - isinstance(node.targets[0], ast.Name) and \ - var != node.targets[0].id: # Prevent cyclic assignment. - if not self.contains_target_var(node.value, var): - continue - target_var = node.targets[0].id - if target_var in self.checked_vars: - continue - key = self.get_split_axes_key(target_var) - if key is not None: - return key - self.checked_vars.append(var) - return None - -class TilingAxesParser(AutoParser): +class TilingAxesParser(AxesKeyParser): """ Extracts the tiling axis parameters from triton kernel code. The parsing is based on the - `tl.arange` and `tl.range` statement. This class identifies potential tiling axes by analyzing - the usage of the `tl.arange` and `tl.range` within `for` loop in the program. Common parameters - between `tl.range` and `tl.arange` are extracted. It then filters these candidates based on a + `tl.arange`, `tl.range` and `range()` statement. This class identifies potential tiling axes by analyzing + the usage of the `range` and `tl.range` within `for` loop in the program. Common parameters + between `range()` or `tl.range` and `tl.arange` are extracted. It then filters these candidates based on a list of candidate parameters (parameters not provided by the user). 
After that, it confirms the - tiling axis corresponding to the current parameter using mask comparison and the `key` passed + tiling axis corresponding to the current parameter using mask comparison and the `keys` passed in `autotune`. Note: @@ -170,22 +240,20 @@ class TilingAxesParser(AutoParser): only those parameters that can be dynamically adjusted through the autotune process are considered. """ - def __init__(self, func_ast: ast.AST, key, candidates_params: List[str]): + def __init__(self, func_ast: ast.AST, keys: Dict[str, str], candidates_params: List[str]): """ :param func_ast: Abstract syntax tree of the triton kernel function :type func_ast: ast.AST - :param key: a dict of axis name: argument name, used to confirm the tiling axis corresponding to + :param keys: a dict of axis name: argument name, used to confirm the tiling axis corresponding to the tiling axis parameters. - :type key: Dict[str, str] - :param candidatas_params: a list of parameters names that were not provided by the user when calling + :type keys: Dict[str, str] + :param candidates_params: a list of parameters names that were not provided by the user when calling triton kernel function. The parser will only consider these parameters as potential tiling axis parameters. :type candidates_params: List[str] """ - super().__init__(func_ast) + super().__init__(func_ast, keys) self.tiling_axes = dict() - self.key = key - self.checked_vars = list() self.candidates_params = candidates_params self.candidates_params_for_loop = list() @@ -198,16 +266,26 @@ def visit_For(self, node): len(node.iter.args) == 3 and \ isinstance(node.iter.args[2], ast.Name): for_loop_param = node.iter.args[2].id - if for_loop_param in self.candidates_params: + if for_loop_param in self.candidates_params and \ + for_loop_param not in self.candidates_params_for_loop: self.candidates_params_for_loop.append(for_loop_param) self.generic_visit(node) def visit_Assign(self, node): if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): + # handle FloorDiv + if isinstance(node.value, ast.BinOp) and isinstance(node.value.op, ast.FloorDiv): + denominator = node.value.right + if isinstance(denominator, ast.Name) and \ + denominator.id in self.candidates_params and \ + denominator.id not in self.candidates_params_for_loop: + self.candidates_params_for_loop.append(denominator.id) + self.visit(self.func_ast) + tiling_axes_val = self.get_tiling_axes_val(node.value) if tiling_axes_val is not None and \ tiling_axes_val in self.candidates_params_for_loop: - tiling_axes_key = self.get_tiling_axes_key(tiling_axes_val) + tiling_axes_key = self.get_axis(tiling_axes_val) if tiling_axes_key: self.tiling_axes[tiling_axes_key] = tiling_axes_val self.generic_visit(node) @@ -218,7 +296,9 @@ def get_tiling_axes_val(self, node): isinstance(node.func.value, ast.Name) and \ node.func.value.id == 'tl': if isinstance(node.args, list) and len(node.args) == 2: - return node.args[1].id + for param in self.candidates_params_for_loop: + if self.contains_target_var(node.args[1], param): + return param for _, value in ast.iter_fields(node): if isinstance(value, list): @@ -232,38 +312,161 @@ def get_tiling_axes_val(self, node): return val return None - def get_tiling_axes_key(self, var): + +class ReductionAxesParser(AxesKeyParser): + """ + Extracts the reduction axis from triton kernel code. The parsing is based on the + reduction function (eg. tl.max, tl.min, tl.sum, ...). 
This class identifies the + dimensions of reduction operations by analyzing the reduction function calls in + the program. After that, It confirms the reduction axis corresponding to the current + parameter using mask comparison and the keys passed in autotune. + + Note: + 1. The call to the reduction function must start with 'tl', meaning the function must + be a function from triton.language + 2. It's preferable to specify the reduction axis dimension in the reduction function + using keyword arguments(eg. axis=xxx). Otherwise, specifying it via positional + arguments may lead to errors. + 3. Mask comparison must be performed on the potential reduction axis length parameters, + and the comparison parameters or target parameters of the comparison expression + must be sliced. Otherwise, the correspondence between dimensions and axes cannot + be confirmed, which will lead to failure in parsing the reduction axis. + 4. The identified reduction axes are limited to the candidate list provided in the keys. + """ + + def __init__(self, func_ast: ast.AST, keys: Dict[str, str]): + """ + :param func_ast: Abstract syntax tree of the triton kernel function + :type func_ast: ast.AST + :param keys: a dict of axis name: argument name, used to confirm the reduction axis. + :type keys: Dict[str, str] + """ + super().__init__(func_ast, keys) + self.reduction_axes = list() + self.reduction_func = ('sum', 'xor_sum', 'max', 'min', 'argmax', 'argmin') # tl.xxx + + def parse(self) -> List[str]: + super().parse() + return self.reduction_axes + + def visit_Call(self, node): + if not isinstance(node.func, ast.Attribute): + return + func = node.func + if not isinstance(func.value, ast.Name) or func.value.id != 'tl': + self.generic_visit(node) + return + if func.attr not in self.reduction_func: + return + + args = node.args + if len(args) == 1: + keywords = node.keywords + for keyword in keywords: + if keyword.arg == 'axis': + if isinstance(keyword.value, ast.Constant): + axis_dim = keyword.value.value + elif len(args) == 2: + if isinstance(args[1], ast.Constant): # check the second param + axis_dim = args[1].value + else: + return + + reduction_axis = self.get_axis(axis_dim) + if reduction_axis and reduction_axis not in self.reduction_axes: + self.reduction_axes.append(reduction_axis) + + def get_axis(self, axis_dim: int): + """ + Override the parent class method to accept an integer axis dimension + instead of a string. + + :param axis_dim: + :type axis_dim: int + """ + if axis_dim in self.checked_vars: + return None + self.checked_vars.append(axis_dim) for node in ast.walk(self.func_ast): - if isinstance(node, ast.Compare): - if not isinstance(node.left, ast.Name) or \ - not isinstance(node.comparators[0], ast.Name): - continue - if var == node.left.id: - compared_var = node.comparators[0].id - elif var == node.comparators[0].id: - compared_var = node.left.id - else: - continue - for k, v in self.key.items(): - if v == compared_var: - return k - elif isinstance(node, ast.Assign): - if len(node.targets) == 1 and \ - isinstance(node.targets[0], ast.Name) and \ - var != node.targets[0].id: # Prevent cyclic assignment. 
- if not self.contains_target_var(node.value, var): - continue - target_var = node.targets[0].id - if target_var in self.checked_vars: - continue - key = self.get_tiling_axes_key(target_var) - if key is not None: - return key - self.checked_vars.append(var) + if not isinstance(node, ast.Assign): + continue + reduction_axis = self.handle_assign_node(axis_dim, node) + if reduction_axis: + return reduction_axis return None + def handle_assign_node(self, axis_dim: int, node): + if not isinstance(node.value, ast.Compare): + return None + + # only support less than + if len(node.value.ops) != 1 or not isinstance(node.value.ops[0], ast.Lt): + return None + + target_axis_len = None + for axis_len in self.keys.values(): + if self.contains_target_var(node.value, axis_len): + target_axis_len = axis_len + break + if not target_axis_len: + return None + + # handel compare left var + if isinstance(node.value.left, ast.Name): + if self.check_compare_left(node.value.left.id, axis_dim): + reduction_axis = next((k for k, v in self.keys.items() if target_axis_len == v), None) + if reduction_axis and reduction_axis not in self.reduction_axes: + return reduction_axis + # handel compare target var + if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name): + if self.check_compare_target(node.targets[0].id, axis_dim): + reduction_axis = next((k for k, v in self.keys.items() if target_axis_len == v), None) + if reduction_axis and reduction_axis not in self.reduction_axes: + return reduction_axis + return None -class LowDimsAxesParser(AutoParser): + def check_compare_left(self, var, axis_dim): + for node in ast.walk(self.func_ast): + if not isinstance(node, ast.Assign): + continue + if len(node.targets) != 1 or \ + not isinstance(node.targets[0], ast.Name) or \ + node.targets[0].id != var: + continue + if self.is_current_dim_slice(node.value, axis_dim): + return True + return False + + def check_compare_target(self, var, axis_dim, node=None): + if not node: + node = self.func_ast + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if self.check_compare_target(var, axis_dim, item): + return True + elif isinstance(value, ast.AST): + if isinstance(value, ast.Subscript): + if not isinstance(value.value, ast.Name) or value.value.id != var: + continue + if self.is_current_dim_slice(value, axis_dim): + return True + else: + if self.check_compare_target(var, axis_dim, value): + return True + return False + + def is_current_dim_slice(self, node, dim): + for node in ast.walk(node): + if not isinstance(node, ast.Subscript) or not isinstance(node.slice, ast.Tuple): + continue + elts = node.slice.elts + if len(elts) != 0 and isinstance(elts[dim], ast.Slice): + return True + return False + + +class LowDimsAxesParser(AxesKeyParser): """ Extracts the low dimensions axis from triton kernel code. The parsing is based on the `tl.arange` statement. This class identifies low dimensions axis by analyzing the usage @@ -271,7 +474,7 @@ class LowDimsAxesParser(AutoParser): their associated operations. Then it checks if these variables are involved in slicing operations to determine dimension expansion and filters out variables that are expanded in non-lowest dimensions. After that, it compares the extracted variables with the provided - `key` to map them to specific low-dimensional axis. + `keys` to map them to specific low-dimensional axis. Note: 1. 
low dimensions axis must be calculated within the `tl.arange` function and involved in @@ -280,20 +483,19 @@ class LowDimsAxesParser(AutoParser): would lead to parameter parsing failure. (eg. mask = offsets < n_elements). """ - def __init__(self, func_ast: ast.AST, key: Dict[str, str]): + def __init__(self, func_ast: ast.AST, keys: Dict[str, str]): """ :param func_ast: Abstract syntax tree of the triton kernel function :type func_ast: ast.AST - :param key: a dict of axis name: argument name, used to confirm the low-dimensional axis. - :type key: Dict[str, str] + :param keys: a dict of axis name: argument name, used to confirm the low-dimensional axis. + :type keys: Dict[str, str] """ - super().__init__(func_ast) + super().__init__(func_ast, keys) self.low_dims_axis = list() - self.key = key - self.checked_compared_vars = list() + self.keys = keys self.checked_slice_vars = list() - def parse(self): + def parse(self) -> List[str]: super().parse() return self.low_dims_axis @@ -304,12 +506,12 @@ def visit_Assign(self, node): if isinstance(tl_arange_node, ast.Call): partin_other_slice = [False] if self.is_partin_low_dim_slice(node.targets[0].id, partin_other_slice): - low_dims_axis = self.get_low_dims_axes(node.targets[0].id) - if not partin_other_slice[0]: - low_dims_axis = self.get_low_dims_axes(node.targets[0].id) + low_dims_axis = self.get_axis(node.targets[0].id) + elif not partin_other_slice[0]: + low_dims_axis = self.get_axis(node.targets[0].id) elif isinstance(tl_arange_node, ast.Subscript) and \ self.is_low_dim_slice(tl_arange_node, [False]): - low_dims_axis = self.get_low_dims_axes(node.targets[0].id) + low_dims_axis = self.get_axis(node.targets[0].id) if low_dims_axis and low_dims_axis not in self.low_dims_axis: self.low_dims_axis.append(low_dims_axis) @@ -351,61 +553,36 @@ def is_low_dim_slice(self, node: ast.Subscript, partin_other_slice): if not isinstance(node.slice, ast.Tuple) or not isinstance(node.slice.elts, list): return False elts = node.slice.elts - if len(elts) != 0 and not isinstance(elts[len(elts) - 1], ast.Slice): + if len(elts) != 0 and not isinstance(elts[-1], ast.Slice): partin_other_slice[0] = True return False return True - def is_partin_low_dim_slice(self, var, partin_other_slice): - for node in ast.walk(self.func_ast): - if isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name): - if var == node.value.id and self.is_low_dim_slice(node, partin_other_slice): + def is_partin_low_dim_slice(self, var, partin_other_slice, node=None): + if not node: + node = self.func_ast + for child_node in ast.walk(node): + if isinstance(child_node, ast.Subscript) and isinstance(child_node.value, ast.Name): + if var == child_node.value.id and self.is_low_dim_slice(child_node, partin_other_slice): return True - elif isinstance(node, ast.Assign): - if len(node.targets) == 1 and \ - isinstance(node.targets[0], ast.Name) and \ - var != node.targets[0].id: # Prevent cyclic assignment. - if not self.contains_target_var(node.value, var): + elif isinstance(child_node, ast.Assign): + if len(child_node.targets) == 1 and \ + isinstance(child_node.targets[0], ast.Name) and \ + var != child_node.targets[0].id: # Prevent cyclic assignment. 
+ if not self.contains_target_var(child_node.value, var): continue - target_var = node.targets[0].id + target_var = child_node.targets[0].id if target_var in self.checked_slice_vars: continue + + if self.is_partin_low_dim_slice(var, partin_other_slice, child_node.value): + return True if self.is_partin_low_dim_slice(target_var, partin_other_slice): return True self.checked_slice_vars.append(var) return False - def get_low_dims_axes(self, var): - for node in ast.walk(self.func_ast): - if isinstance(node, ast.Compare): - if not isinstance(node.left, ast.Name) or \ - not isinstance(node.comparators[0], ast.Name): - continue - if var == node.left.id: - compared_var = node.comparators[0].id - elif var == node.comparators[0].id: - compared_var = node.left.id - else: - continue - for k, v in self.key.items(): - if v == compared_var: - return k - elif isinstance(node, ast.Assign): - if len(node.targets) == 1 and \ - isinstance(node.targets[0], ast.Name) and \ - var != node.targets[0].id: # Prevent cyclic assignment. - if not self.contains_target_var(node.value, var): - continue - target_var = node.targets[0].id - if target_var in self.checked_compared_vars: - continue - key = self.get_low_dims_axes(target_var) - if key is not None: - return key - self.checked_compared_vars.append(var) - return None - class PtrNumsParser(AutoParser): """ @@ -427,10 +604,12 @@ class PtrNumsParser(AutoParser): the input parameter through two or more levels of computation are not counted. """ - def __init__(self, func_ast: ast.AST, miss_params: List[str]): + def __init__(self, func_ast: ast.AST, keys: Dict[str, str], miss_params: List[str]): """ :param func_ast: Abstract syntax tree of the triton kernel function :type func_ast: ast.AST + :param keys: a dict of axis name: argument name, used to exclude potential ptr params. + :type keys: Dict[str, str] :param miss_params: a list of parameters names that were not provided by the user when calling triton kernel function. :type miss_params: List[str] @@ -439,6 +618,7 @@ def __init__(self, func_ast: ast.AST, miss_params: List[str]): self.checked_vars = list() self.ptr_nums = 0 self.ptr_params = list() + self.keys = keys self.miss_params = miss_params self.constexpr_params = list() @@ -461,7 +641,7 @@ def visit_FunctionDef(self, node): self.constexpr_params.append(arg.arg) continue - if self.is_in_addr_calc(arg.arg): + if self.is_in_addr_calc(arg.arg) and arg.arg not in self.keys.values(): self.ptr_params.append(arg.arg) self.ptr_nums += 1 diff --git a/third_party/ascend/backend/runtime/autotuner.py b/third_party/ascend/backend/runtime/autotuner.py new file mode 100644 index 000000000..3ea038105 --- /dev/null +++ b/third_party/ascend/backend/runtime/autotuner.py @@ -0,0 +1,528 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from __future__ import annotations + +import builtins +import os +import time +import copy +from typing import Dict, List +from torch import Tensor + +from triton.runtime.autotuner import Autotuner, Config + +from .utils import get_byte_per_numel, is_valid_axis_name, valid_axis_names +from .autoparser import SplitAxesParser, TilingAxesParser, ReductionAxesParser, LowDimsAxesParser, PtrNumsParser + + +class AutoTilingTuner(Autotuner): + """ + Automatic generateing candidate tiling configs and evaluating their performance to get the best config. + """ + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by: Dict = None, + warmup=None, + rep=None, + use_cuda_graph=False, + do_bench=None, + auto_profile_dir=None, + hints=None, + ): + """ + :param key: a list of argument name, where the change of arguments in value will triger re-generating candidates configs and evaluating. + The parameters in the list will be assigned axis names in sequence, with the axis name being in + {'x','y','z','w','v','t','rx','ry','rz','rw','rv','rt}, where the prefix 'r' means a reduction axis. + Only the axis name in this param should add perfix 'r' if it's a reduction axis. + :type key: List[str] + """ + super().__init__( + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook, + post_hook, + prune_configs_by, + warmup, + rep, + use_cuda_graph, + do_bench, + ) + if not hints: + self.hints = {} + else: + self.hints = hints + split_params = self.hints.get("split_params", None) + tiling_params = self.hints.get("tiling_params", None) + low_dim_axes = self.hints.get("low_dim_axes", None) + reduction_axes = self.hints.get("reduction_axes", None) + self._init_axis_params(key, split_params, tiling_params, low_dim_axes, reduction_axes) + + self.auto_gen_config = not configs or self.hints.get("auto_gen_config", False) + self.gen_configs = [] # generated configs from TileGenerator + self.auto_profile_dir = auto_profile_dir + if not configs: + self.user_configs = [] + else: + self.user_configs = configs + self.is_simt_mode = False + self.user_specified_warps = None + self.print_autotuning = os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" + + def _init_axis_params(self, key, split_params, tiling_params, low_dim_axes, reduction_axes): + if isinstance(key, list): + if (split_params or tiling_params or low_dim_axes or reduction_axes): + raise ValueError( + "If any axis-related parameters (split_params, tiling_params, low_dim_axes, reduction_axes)" + " are provided, 'key' must be a dict, not a list.") + if len(key) > len(valid_axis_names): + raise ValueError("Number of parameters exceeds the number of available axes.") + self.keys = {axis: param for axis, param in zip(valid_axis_names, key)} + elif isinstance(key, dict): + if not set(key.keys()).issubset(set(valid_axis_names)): + raise ValueError("All keys in 'key' must be valid axis names. 
Got unexpected keys.") + self.keys = key + if any([split_params, tiling_params, low_dim_axes, reduction_axes]) is None: + raise ValueError( + "If 'key' is a dict, all axis-related parameters (split_params, tiling_params, low_dim_axes," + " reduction_axes) must be provided.") + if not isinstance(split_params, dict): + raise ValueError("split_params must be a dict, got: {}".format(type(split_params))) + if not isinstance(tiling_params, dict): + raise ValueError("tiling_params must be a dict, got: {}".format(type(tiling_params))) + if not isinstance(low_dim_axes, list): + raise ValueError("low_dim_axes must be a list, got: {}".format(type(low_dim_axes))) + if not isinstance(reduction_axes, list): + raise ValueError("reduction_axes must be a list, got: {}".format(type(reduction_axes))) + + used_axes = set(split_params.keys()).union( + tiling_params.keys(), + low_dim_axes, + reduction_axes, + ) + if not used_axes.issubset(self.keys.keys()): + raise ValueError( + "The following axes are used but not present in the 'key': {}".format(used_axes - + set(self.keys.keys()))) + + self.split_params = split_params + self.tiling_params = tiling_params + self.low_dim_axes = low_dim_axes + self.reduction_axes = reduction_axes + self.dual_reduction = False + self.persistent_reduction = False + self.num_buffers = -1 + + def _autoparse_axis_params(self, all_args): + miss_params = [arg for arg in self.arg_names if arg not in all_args.keys()] + # parse pointer params nums + if self.num_buffers == -1: + self.num_buffers = self._autoparse_ptr_nums(all_args) + + # parse autotiling axes + # reduction axis must be parsed before other axes. it will alter the key + if not self.reduction_axes: + self.reduction_axes = self._autoparse_reduction_axes() + if len(self.reduction_axes) >= 2: + self.dual_reduction = True + + if not self.low_dim_axes: + self.low_dim_axes = self._autoparse_low_dim_axes() + + if len(self.reduction_axes) == 1 and \ + self.reduction_axes[0] == self.low_dim_axes[0] and \ + all_args.get(self.keys[self.reduction_axes[0]], float("inf")) < 1024: + self.persistent_reduction = True + + if not self.split_params: + self.split_params = self._autoparse_split_params(miss_params) + miss_params = [arg for arg in miss_params if arg not in self.split_params.values()] + if not self.tiling_params: + self.tiling_params = self._autoparse_tiling_params(miss_params) + miss_params = [arg for arg in miss_params if arg not in self.tiling_params.values()] + if miss_params: + raise ValueError(f"Missing required arguments: {miss_params}. " + f"These arguments must be explicitly provided and cannot be automatically tuned. 
" + f"Please ensure that these arguments are passed when calling the function.") + + def _gen_tile_configs(self, kv_dict: Dict[str, int], dtype: torch.dtype) -> List[Config]: + from .tile_generator import KernelMeta, TileGenerator + + axis_sizes = {} + for k, v in kv_dict.items(): + if not is_valid_axis_name(k): + continue + if not isinstance(v, int): + raise ValueError(f"Not supported dim type: {type(v)}, `int` is the only supported type") + axis_sizes[k] = v + + kernel_meta = KernelMeta( + axis_sizes, + self.split_params, + self.tiling_params, + self.low_dim_axes, + dtype, + self.persistent_reduction, + self.dual_reduction, + self.num_buffers, + self.is_simt_mode, + ) + tile_gen = TileGenerator(kernel_meta=kernel_meta) + tile_gen.descend_split_tiling() + + self.gen_configs.clear() + self.gen_configs = tile_gen.configs + + if self.is_simt_mode: + _default_cand_num_warps = [8, 16, 32, 64] + cand_num_warps = (_default_cand_num_warps + if self.user_specified_warps is None else [self.user_specified_warps]) + simt_configs = [] + for base_cfg in self.gen_configs: + for num_warps in cand_num_warps: + new_cfg = copy.deepcopy(base_cfg) + new_cfg.num_warps = num_warps + simt_configs.append(new_cfg) + + if self.print_autotuning: + print(f"Triton autotuning: Expanded to {len(simt_configs)} SIMT configs (with warps: {cand_num_warps})") + + self.gen_configs = simt_configs + + if len(self.gen_configs) == 0: + print("[WARNING] The generated candidate tiling configs are empty based on provided parameters!") + + if self.print_autotuning: + print("Generated configs number: {}".format(len(self.gen_configs))) + + def generate_key_and_configs(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + self.is_simt_mode = kwargs.get('force_simt_only', False) + if 'num_warps' in kwargs and kwargs['num_warps'] is not None: + self.user_specified_warps = kwargs['num_warps'] + + # generate key + all_args = {**self.nargs, **kwargs} + _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} + key = [_args[v] for _, v in self.keys.items() if v in _args] + + # Currently, we use the dtype with maximum byte length + dtype = None + for _, arg in _args.items(): + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + dtype = (arg.dtype if get_byte_per_numel(arg.dtype) >= get_byte_per_numel(dtype) else dtype) + if dtype is None: + raise NotImplementedError("Not support for non-Tensor inputs") + + key = tuple(key) + if key not in self.cache: + if self.auto_gen_config: + self._autoparse_axis_params(all_args) + _kv_dict = {k: _args[v] for k, v in self.keys.items() if v in _args} + self._gen_tile_configs(_kv_dict, dtype) + if len(self.gen_configs) == 0 and len(self.user_configs) == 0: + self.configs = [ + Config( + {}, + num_warps=4, + num_stages=2, + num_ctas=1, + num_buffers_warp_spec=0, + num_consumer_groups=0, + reg_dec_producer=0, + reg_inc_consumer=0, + ) + ] + else: + self.configs = self.gen_configs + self.user_configs + return key + + def run(self, *args, **kwargs): + key = self.generate_key_and_configs(*args, **kwargs) + used_cached_result = True + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + if len(pruned_configs) > 1: + used_cached_result = False + bench_start = time.time() + timings = self._batch_bench(*args, configs=pruned_configs, **kwargs) + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()} + 
self.pre_hook(full_nargs, reset_only=True) + self.configs_timings = timings + config = self.cache[key] + else: + config = pruned_configs[0] + else: + config = self.cache[key] + + self.best_config = config + if self.print_autotuning and not used_cached_result: + print(f"Triton autotuning for function {self.base_fn.__name__} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};") + + if not used_cached_result and self.auto_profile_dir is not None: + self._profile(*args, config=self.best_config, **kwargs) + if config.pre_hook is not None: + full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} + config.pre_hook(full_nargs) + final_kwargs = dict(config.all_kwargs(), **kwargs) + ret = self.fn.run( + *args, + **final_kwargs, + ) + self.nargs = None + return ret + + def _batch_bench(self, *args, configs, **kwargs): + from triton.compiler.errors import CompileTimeAssertionFailure, MLIRCompilationError + from triton.runtime.errors import OutOfResources + + kernels_call = {config: self._make_kernel_call(*args, config=config, **kwargs) for config in configs} + run_fns = {} + exc = None + exc_stack = "" + + for config, fn in kernels_call.items(): + try: + fn() + run_fns[config] = fn + except (CompileTimeAssertionFailure, MLIRCompilationError, OutOfResources) as e: + import traceback + exc_stack = traceback.format_exc() + exc = e + + if len(run_fns) == 0: + raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc} \nStack trace: {exc_stack}") + + if len(run_fns) == 1: + # we ignore expensive profiling method when only single config is left + return {config: self.do_bench(fn) for config, fn in run_fns.items()} + + use_profiling = os.getenv("TRITON_BENCH_METHOD", "default").lower() == "npu" + if use_profiling: + from ..testing import do_bench_npu + + time_cost = do_bench_npu(list(run_fns.values()), clear_l2_cache=False) + assert len(time_cost) == len(run_fns) + return {config: cost for config, cost in zip(run_fns.keys(), time_cost)} + else: + return {config: self.do_bench(fn) for config, fn in run_fns.items()} + + def _make_kernel_call(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 
+ " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(full_nargs) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(full_nargs, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + self.post_hook(full_nargs, exception=None) + + return kernel_call + + def warmup(self, *args, **kwargs): + _ = self.generate_key_and_configs(*args, **kwargs) + pruned_configs = self.prune_configs(kwargs) + ret = [] + for config in pruned_configs: + ret.append(self.fn.warmup(*args, **kwargs, **config.all_kwargs())) + self.nargs = None + return ret + + def _profile(self, *args, config, **meta): + from ..testing import do_bench_npu + + kernel_call = self._make_kernel_call(*args, config=config, **meta) + do_bench_npu(kernel_call, prof_dir=self.auto_profile_dir, keep_res=True) + + def _autoparse_split_params(self, candidates_params: List[str]) -> Dict[str, str]: + """ + Extracts the split axis parameters from triton kernel code. + """ + func_ast = self.fn.parse() + parser = SplitAxesParser(func_ast, self.keys, candidates_params) + split_axes = parser.parse() + if self.print_autotuning: + print(f"Ascend autotuning parse split axes: {split_axes}") + return split_axes + + def _autoparse_tiling_params(self, candidates_params: List[str]) -> Dict[str, str]: + """ + Extracts the tiling axis parameters from triton kernel code. + """ + func_ast = self.fn.parse() + parser = TilingAxesParser(func_ast, self.keys, candidates_params) + tiling_axes = parser.parse() + if self.print_autotuning: + print(f"Ascend autotuning parse tiling axes: {tiling_axes}") + return tiling_axes + + def _autoparse_reduction_axes(self) -> List[str]: + """ + Extracts the reduction axis parameters from triton kernel code. + """ + func_ast = self.fn.parse() + parser = ReductionAxesParser(func_ast, self.keys) + reduction_axes = parser.parse() + for axis in reduction_axes: + self.keys[f"r{axis}"] = self.keys.pop(axis) + reduction_axes = [f"r{axis}" for axis in reduction_axes] + + if self.print_autotuning: + print(f"Ascend autotuning parse keys: {self.keys} \n" + f"Ascend autotuning parse reduction axes: {reduction_axes}") + return reduction_axes + + def _autoparse_low_dim_axes(self) -> List[str]: + """ + Extracts the low dimension axis from triton kernel code. + """ + func_ast = self.fn.parse() + parser = LowDimsAxesParser(func_ast, self.keys) + low_dim_axes = parser.parse() + if len(low_dim_axes) < 1: + raise ValueError(f"Failed to parse low-dimensional axes.") + if self.print_autotuning: + print(f"Ascend autotuning parse low dimensional axes: {low_dim_axes}") + return low_dim_axes + + def _autoparse_ptr_nums(self, all_args: dict) -> int: + """ + Counts the number of pointer parameters from triton kernel code. 
+ """ + ptr_nums = 0 + ptr_params = list() + for k, v in all_args.items(): + if isinstance(v, Tensor): + ptr_nums += 1 + ptr_params.append(k) + + if self.print_autotuning: + print(f"Ascend autotuning parse pointer params: {ptr_params}, pointer nums: {ptr_nums}") + return ptr_nums + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, + warmup=None, rep=None, use_cuda_graph=False, do_bench=None, *, auto_prof_dir=None, hints=None): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to + :code:`"1"`, Triton will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param pre_hook: a function that will be called before the kernel is called. + This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. + :type pre_hook: lambda args, reset_only + :param post_hook: a function that will be called after the kernel is called. + This overrides the default post_hook used for 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'exception': the exception raised by the kernel in case of a compilation or runtime error. + :type post_hook: lambda args, exception + :param warmup: warmup time (in ms) to pass to benchmarking (deprecated). + :type warmup: int + :param rep: repetition time (in ms) to pass to benchmarking (deprecated). + :type rep: int + :param do_bench: a benchmark function to measure the time of each run. 
+ :type do_bench: lambda fn, quantiles + :param auto_prof_dir: the specified directory to store the profiling results of the best config. + If this parameter is None or the best config is retrieved from cache, the profiling process will be ignored. + :type auto_prof_dir: str + :param hints: a dict of autotune hint auguments passed to AutoTilingTuner. + """ + + def decorator(fn): + return AutoTilingTuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_prof_dir, + hints=hints) + + return decorator diff --git a/third_party/ascend/backend/spec/triton/runtime/tile_generator.py b/third_party/ascend/backend/runtime/tile_generator.py similarity index 86% rename from third_party/ascend/backend/spec/triton/runtime/tile_generator.py rename to third_party/ascend/backend/runtime/tile_generator.py index 260947bc6..9adf01395 100644 --- a/third_party/ascend/backend/spec/triton/runtime/tile_generator.py +++ b/third_party/ascend/backend/runtime/tile_generator.py @@ -1,3 +1,25 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + from __future__ import annotations import functools @@ -9,13 +31,15 @@ Tuple, ) +from triton.runtime.autotuner import Config + from .utils import ( get_byte_per_numel, next_power_of_2, num_vector_core, - num_ub_max, + ub_size_in_kbytes, + rf_size_in_kbytes, ) -from .autotuner import Config @dataclass @@ -46,7 +70,8 @@ def __init__( dtype: torch.dtype, persistent_reduction: bool, dual_reduction: bool, - input_ptr_num: int, + num_buffers: int, + is_simt_mode: bool, ): """ :param split_params: a dict of axis name: argument name, the argument is an adjustable parameter in a split axis, such as 'XBLOCK'. 
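The KernelMeta/TileGenerator plumbing above is fed by the `autotune` decorator defined earlier in this patch. For orientation, a hypothetical call site might look as follows; it assumes the `_patch_autotune` hook from `backend/runtime/__init__.py` has re-pointed `triton.autotune` at this implementation, and every kernel, axis and parameter name below is illustrative:

    import triton
    import triton.language as tl

    @triton.autotune(
        configs=[],                              # empty -> candidate tiling configs are generated automatically
        key={'x': 'x_numel', 'ry': 'r_numel'},   # axis name -> kernel argument holding that axis length
        hints={
            'split_params': {'x': 'XBLOCK'},
            'tiling_params': {'x': 'XBLOCK_SUB', 'ry': 'R_BLOCK'},
            'low_dim_axes': ['ry'],
            'reduction_axes': ['ry'],
        },
    )
    @triton.jit
    def row_sum_kernel(in_ptr, out_ptr, x_numel, r_numel,
                       XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr, R_BLOCK: tl.constexpr):
        ...  # body as in the sketch next to the autoparser classes above

When `hints` is omitted, the same axis information is recovered from the kernel source by the AST parsers in `autoparser.py`; supplying it up front skips the corresponding parsing.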
@@ -73,7 +98,6 @@ def __init__( prefix = "" if name.startswith("r"): prefix = "r" - name = name[1:] is_split_axis = name in split_params is_tiling_axis = name in tiling_params @@ -99,7 +123,8 @@ def __init__( self.dtype = dtype self.persistent_reduction = persistent_reduction self.dual_reduction = dual_reduction - self.input_ptr_num = input_ptr_num + self.num_buffers = num_buffers + self.is_simt_mode = is_simt_mode @classmethod def _validate_axis( @@ -151,8 +176,10 @@ def __init__(self, kernel_meta: KernelMeta): self.configs = [] self.dtype_bytes = get_byte_per_numel(kernel_meta.dtype) - self.input_ptr_num = 3 if kernel_meta.input_ptr_num == 0 else min(kernel_meta.input_ptr_num, 3) - self.max_numel_threshold = num_ub_max // self.input_ptr_num * 1024 + self.num_buffers = 3 if kernel_meta.num_buffers == 0 else min(kernel_meta.num_buffers, 3) + self.is_simt_mode = kernel_meta.is_simt_mode + local_mem_size = (rf_size_in_kbytes if self.is_simt_mode else ub_size_in_kbytes) + self.max_numel_threshold = local_mem_size * 1024 // self.dtype_bytes // self.num_buffers self.max_total_numel = functools.reduce(lambda x, y: x * y, [x.block_size for x in self.blocks]) if self.blocks else 1 self.tiny_kernel = self.max_total_numel < 128 * 1024 @@ -189,17 +216,14 @@ def calcu_last_split_blocks(self, axis_idx): last_blocks = (self.numels[axis_idx] + last_splits - 1) // last_splits return last_blocks - def aligned_numel(self, numel): - min_numel = 32 // self.dtype_bytes - if numel <= min_numel: - return numel - aligned = ((numel + min_numel - 1) // min_numel) * min_numel - return aligned + def aligned_numel(self, numel, align_bytes=32): + if self.is_simt_mode: + return next_power_of_2(numel) - def valid_tile_numel(self, tile_numel): - byte_num = self.dtype_bytes - max_numel = self.max_numel_threshold // byte_num - return tile_numel <= max_numel + align_numel = align_bytes // self.dtype_bytes + if numel <= align_numel: + return numel + return ((numel + align_numel - 1) // align_numel) * align_numel def calculate_tile_numel(self): tile_numel = 1 @@ -217,7 +241,10 @@ def fill_config(self, cfg, candi_block): continue block_info = self.blocks[axis.index] if axis.is_split_axis: - cfg[block_info.block_name] = candi_block[axis.index] + curr_numel = candi_block[axis.index] + if not axis.is_tiling_axis: + curr_numel = self.aligned_numel(curr_numel) + cfg[block_info.block_name] = curr_numel if axis.is_tiling_axis: tiling_numel = self.aligned_numel(block_info.sub_block_size) cfg[block_info.sub_block_name] = min(tiling_numel, candi_block[axis.index]) @@ -233,7 +260,8 @@ def add_to_configs(self, candi_block): self.fill_config(newcfg, candi_block) tile_numel = self.calculate_tile_numel() stop_numel_threshold = 0 if len(self.configs) < 10 or self.tiny_kernel else self.stop_numel + 100 - if self.valid_tile_numel(tile_numel) and not self.find_config(newcfg) and tile_numel >= stop_numel_threshold: + if (tile_numel <= self.max_numel_threshold and tile_numel >= stop_numel_threshold + and not self.find_config(newcfg)): self.configs.append(Config(newcfg, num_warps=1, num_stages=1)) return True return False @@ -313,11 +341,10 @@ def calc_total_programs(): if not slow_decend_split: self.blocks[axis_idx].block_size = numel // 2 - self.blocks[axis_idx].sub_block_size = self.blocks[axis_idx].block_size else: step = numel // 4 if numel // 4 > 1 else 1 self.blocks[axis_idx].block_size = numel - step - self.blocks[axis_idx].sub_block_size = self.blocks[axis_idx].block_size + self.blocks[axis_idx].sub_block_size = 
self.blocks[axis_idx].block_size total_programs = calc_total_programs() if self.blocks[axis_idx].block_size == 1 and (total_programs > program_threshold or self.dual_reduction): @@ -325,8 +352,7 @@ def calc_total_programs(): else: if numel >= 32: self.blocks[axis_idx].sub_block_size = next_power_of_2(numel // 2) - else: # numel >4 and numel < 128 : - numel = self.blocks[axis_idx].sub_block_size + else: self.blocks[axis_idx].sub_block_size = numel - 1 return reached_stop_numel @@ -348,7 +374,7 @@ def descend_all_axis(min_numel): continue if numel >= 128: self.blocks[axis.index].sub_block_size = next_power_of_2(numel // 2) - else: # numel >4 and numel < 128 : + else: numel = self.blocks[axis.index].sub_block_size numel = numel // 2 self.blocks[axis.index].sub_block_size = min(self.aligned_numel(numel), next_power_of_2(numel)) diff --git a/third_party/ascend/backend/runtime/utils.py b/third_party/ascend/backend/runtime/utils.py new file mode 100644 index 000000000..f470fc61d --- /dev/null +++ b/third_party/ascend/backend/runtime/utils.py @@ -0,0 +1,121 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
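As a rough sanity check on the max_numel_threshold formula introduced in the tile_generator hunk above, a sketch under assumed inputs (SIMD mode with the 192 KB UB default defined in this new utils module, fp16 data, and num_buffers clamped to 3 as in the tile generator):

    # illustrative only -- not part of this change
    local_mem_size = 192                           # ub_size_in_kbytes default for 910B-class parts
    dtype_bytes = 2                                # torch.float16
    num_buffers = 3                                # kernel_meta.num_buffers clamped to 3
    max_numel_threshold = local_mem_size * 1024 // dtype_bytes // num_buffers
    assert max_numel_threshold == 32768            # about 32K fp16 elements fit per tile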
+ +import torch + +_cached_params = None + + +def _init_npu_params(): + global _cached_params + if _cached_params is not None: + return _cached_params + + from triton.runtime.driver import driver + + target = driver.active.get_current_target() + device = driver.active.get_current_device() + prop = driver.active.utils.get_device_properties(device) + + num_cube_core = prop["num_aicore"] + num_vector_core = prop["num_aicore"] + ub_size_in_kbytes = 192 + rf_size_in_kbytes = None + + ASCEND_VARIANTS = ["Ascend910B", "Ascend910_93", "Ascend910_95"] + if any(variant in target.arch for variant in ASCEND_VARIANTS): + num_vector_core = num_cube_core * 2 + + if '910_95' in target.arch: + ub_size_in_kbytes = 256 + rf_size_in_kbytes = 128 + + _cached_params = { + 'target': target, + 'device': device, + 'prop': prop, + 'num_cube_core': num_cube_core, + 'num_vector_core': num_vector_core, + 'ub_size_in_kbytes': ub_size_in_kbytes, + 'rf_size_in_kbytes': rf_size_in_kbytes, + } + return _cached_params + + +def __getattr__(name): + if name in [ + 'target', 'device', 'prop', 'num_cube_core', 'num_vector_core', 'ub_size_in_kbytes', 'rf_size_in_kbytes' + ]: + return _init_npu_params()[name] + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + +# wrapper npu 32 bytes align, get and pass unalign info to triton meta +# then autotune choose tiling param and send them to bishengIR +byte_per_numel = { + torch.float32: 4, # torch.float32 or torch.float + torch.float64: 8, # torch.float64 or torch.double + torch.float16: 2, # torch.float16 or torch.half + torch.bfloat16: 2, # torch.bfloat16 + torch.int32: 4, # torch.int32 or torch.int + torch.int64: 8, # torch.int64 or torch.long + torch.int16: 2, # torch.int16 or torch.short + torch.int8: 1, # torch.int8 + torch.uint8: 1, # torch.uint8 + torch.bool: 1, # torch.bool + torch.complex32: 4, # torch.complex32 (not yet available in PyTorch as of the latest stable release) + torch.complex64: 8, # torch.complex64 + torch.complex128: 16, # torch.complex128 +} + +valid_axis_names = [ + "x", + "y", + "z", + "w", + "v", + "t", +] + + +def get_byte_per_numel(dtype: torch.dtype) -> int: + return 1 if dtype is None else byte_per_numel[dtype] + + +def is_valid_axis_name(name: str) -> bool: + if name.startswith("r"): + return name[1:] in valid_axis_names + return name in valid_axis_names + + +# move to an appropriate place, currently duplicated with triton.__init__.py +def next_power_of_2(n: int): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n diff --git a/third_party/ascend/backend/spec/__init__.py b/third_party/ascend/backend/spec/__init__.py index 0591b5feb..e6d6026a1 100644 --- a/third_party/ascend/backend/spec/__init__.py +++ b/third_party/ascend/backend/spec/__init__.py @@ -1,72 +1,6 @@ -from .triton.compiler.compiler import * -from .triton.compiler.errors import * -from .triton.compiler.code_generator import * -from .triton.runtime.jit import * -from .triton.runtime.autotuner import * -from .triton.language._utils import * -from .triton.language.core import * -from .triton.language.standard import * -from .triton.language.semantic import * -from .triton.language.math import * -from .triton.testing import * +from .triton.language import * __all__ = [ - # compiler.compiler - 'ext_ASTSource_attrs', 'opt_ascend_compile_speed', 'set_CompiledKernel_metadata_stream', 'handle_compile_error', - 
'compiledKernel_getattribute_disable_init_handles', - # compiler.code_generator - 'ext_CodeGenerator_builder_with_compile_mode', 'for_op_ext_attrs', 'for_op_set_ext_attrs', - 'ext_CodeGenerator_visit_Assign_hint_anno', 'visit_For_ext_support', 'set_bind_sub_block_when_parallel', - 'check_override_bind_sub_block', 'forop_setattr_for_bind_sub_block', 'need_repr_in_CodeGenerator_CompilationError', - # runtime.jit - 'enable_stream_in_kwargs', 'ignore_params_in_JITFunction_run', 'check_grid_size', 'explicit_load_kernel_library', - 'get_JITFunction_spec_attr', 'maps_line_numbers_to_comment_hints', 'attach_line_number_to_comment_mapping', - 'enable_extra_option', - # runtime.autotuner - 'set_Autotuner_auto_profile_dir', 'ext_Autotuner_do_bench_MLIRCompilationError', 'ext_Autotuner_batch_bench', - 'ext_Autotuner_profile', 'default_Config_arg_is_none', 'set_Config_extra_options', 'ext_Config_all_kwargs', - 'ext_Config_to_str', 'new_AutoTilingTuner', - # language._utils - 'block_shape_disable_check_power_of_two', 'get_primitive_bitwidth', - # language.core - "enable_care_padding_load", "ext_cast_set_overflow_modes", "ext_cast_check_overflow_mode", - "ext_trans_unwrap_iterable", "check_dot_deprecated_param_allow_tf32", "check_dot_invalid_input_precision", "gather", - "insert_slice", "extract_slice", "get_element", "__add__", "__radd__", "__sub__", "__rsub__", "__mul__", "__rmul__", - "__mod__", "__lshift__", "__rshift__", "compile_hint", "sort", "multibuffer", "sync_block_all", "sync_block_set", - "sync_block_wait", "load_tensor_descriptor", "store_tensor_descriptor", "make_tensor_descriptor", - "index_select_simd", "dtype_to_ir", "parallel", "index_select", "index_put", "gather_out_to_ub", - "scatter_ub_to_out", "dot_scaled", "range", "core_ext_spec_api_list", "core_tensor_ext_spec_api_list", - # language.semantic - "ret_if_not_create_int_cast", "check_arange_range_power_of_two", "arange_disable_check_power_of_two", - "check_arange_less_than_max_numel", "is_cast_src_dst_scalar_type_equal", "check_unsupported_fp8_fp64", - "ext_dot_operand_types", "dot_check_hf32_input_precision", "dot_disable_check_max_num_imprecise_acc", - "reset_dot_max_num_imprecise_acc", "check_was_bool_to_int8_dtype", "cast_bool_to_specified_dtype", - "check_unexpected_dtype_float", "check_unexpected_dtype_bool", "set_load_legacy_other_input", - "disable_cast_back_when_load_legacy_ptr_is_bool", "set_attr_was_bool_to_int8", "atomic_disable_original_check", - "atomic_cas_disable_element_bitwidth_check", "ext_atomic_cas_element_typechecking", "is_atomic_max_no_bitcast", - "is_atomic_min_no_bitcast", "atomic_max_returning_tensor", "atomic_min_returning_tensor", - "is_float_format_support_bf16", "is_float_format_support_fp16", "floating_mod_returning_tensor", - "logical_check_int1_bitcast", "ext_dot_scaled_validate_lhs_dtype", "ext_dot_scaled_validate_rhs_dtype", - "ext_dot_scaled_check_same_dtype", "dot_scaled_disable_original_check", "ext_dot_scaled_check_lhs_rhs_format", - "dot_scaled_recheck_rhs_scale_is_none", "dot_scaled_check_lhs_scale_is_none", "is_dot_scaled_support_rhs_scale", - "check_dot_scaled_lhs_scale_dtype", "check_dot_scaled_rhs_scale_dtype", "dot_scaled_lhs_bitcast_to_fp_type", - "dot_scaled_rhs_bitcast_to_fp_type", "dot_scaled_lrhs_k_pack", "check_dot_scaled_dimension", - "check_dot_scaled_pack_size", "set_dot_scaled_lhs_scale_handle", "ext_semantic_gather", "ext_semantic_insert_slice", - "ext_semantic_extract_slice", "ext_semantic_get_element", "ext_semantic_compile_hint", "ext_semantic_custom_op", - 
"ext_semantic_sort", "ext_semantic_scalar_constant", "ext_semantic_make_scalar", - "ext_semantic_make_tensor_descriptor", "ext_semantic__load_block_pointer", "ext_semantic_index_select_simd", - "ext_semantic_flip_simd", "ext_semantic_flip", "ext_semantic_static_range", "ext_semantic_embedding_gather", - "ext_semantic_index_put", "ext_semantic_gather_out_to_ub", "ext_semantic_scatter_ub_to_out", - "semantic_ext_spec_api_list", - # language.standard - "sigmoid", "softmax", "isfinited", "finitef", "rint", "atan2", "argmax", "argmin", "topk", "max", - "standard_ext_spec_api_list", - # language.math - "umulhi", "exp", "exp2", "log", "log2", "cos", "sin", "sqrt", "sqrt_rn", "rsqrt", "div_rn", "erf", "tanh", "floor", - "ceil", "fma", "_check_dtype", "cdiv", "isnan", "isinf", "reciprocal", "relu", "log1p", "tan", "atan", "tanh", - "ilogb", "ldexp", "pow", "flip", "atan2", "div_rz", "fmod", "trunc", "round", "math_ext_base_api_list", - "math_ext_spec_api_list", - # testing - 'is_do_bench_npu', 'ext_do_bench_npu', 'patch_triton_language', 'do_bench_npu', 'do_bench_multiple_kernel_npu', - 'testing_ext_spec_api_list' + # triton.language + "language_modify_all", ] diff --git a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td index dba1e9343..0c505a5df 100644 --- a/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/third_party/ascend/backend/spec/include/triton/Dialect/Triton/IR/TritonOps.td @@ -996,7 +996,12 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI ``` }]; - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let arguments = (ins FlatSymbolRefAttr:$callee, + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let results = (outs Variadic); let builders = [ diff --git a/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Ops.cpp b/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Ops.cpp index 8234e4012..84533f263 100644 --- a/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Ops.cpp +++ b/third_party/ascend/backend/spec/lib/Dialect/Triton/IR/Ops.cpp @@ -34,6 +34,7 @@ #include "triton/Dialect/Triton/IR/Utility.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/LogicalResult.h" +#include "flir/include/npu/Dialect/TritonAscend/IR/TritonAscendDialect.h" namespace mlir { namespace triton { @@ -197,7 +198,7 @@ LogicalResult GatherOp::inferReturnTypes( } //-- IndexSelectSimdOp -- -LogicalResult IndexSelectSimdOp::inferReturnTypes( +LogicalResult ascend::IndexSelectSimdOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { diff --git a/third_party/ascend/backend/spec/triton/compiler/code_generator.py b/third_party/ascend/backend/spec/triton/compiler/code_generator.py index 8fb9400a6..172ba90b4 100644 --- a/third_party/ascend/backend/spec/triton/compiler/code_generator.py +++ b/third_party/ascend/backend/spec/triton/compiler/code_generator.py @@ -1,79 +1,1387 @@ -def ext_CodeGenerator_builder_with_compile_mode(options): - return "simt" if options.force_simt_only else "simd" - - -def for_op_ext_attrs(iterator): - return (iterator.disallow_acc_multi_buffer, iterator.flatten, iterator.warp_specialize, iterator.disable_licm) - - -def for_op_set_ext_attrs(for_op, builder, ext_attrs): - disallow_acc_multi_buffer, flatten, 
warp_specialize, disable_licm = ext_attrs - if disallow_acc_multi_buffer: - for_op.set_attr("tt.disallow_acc_multi_buffer", builder.get_unit_attr()) - if flatten: - for_op.set_attr("tt.flatten", builder.get_unit_attr()) - if warp_specialize: - for_op.set_attr("tt.warp_specialize", builder.get_unit_attr()) - if disable_licm: - for_op.set_attr("tt.disable_licm", builder.get_unit_attr()) - - -def ext_CodeGenerator_visit_Assign_hint_anno(code_generator, node, names, values): - import ast - from triton.compiler.code_generator import _is_triton_value - # flagtree: After normal processing, check if we need to add hint annotation - if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): - line_num = node.lineno - # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later - function_def = code_generator.jit_fn.parse() - line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) - flagtree_hints = line_flagtree_hints.get(line_num) - - # Check if this is a tl.load call with dot_pad_only_k hint - if (flagtree_hints and 'dot_pad_only_k' in flagtree_hints and isinstance(node.value, ast.Call) - and isinstance(node.value.func, ast.Attribute) and isinstance(node.value.func.value, ast.Name) - and node.value.func.value.id == 'tl' and node.value.func.attr == 'load'): - - # Add hint annotation to the loaded tensor(s) - for name, value in zip(names, values): - if _is_triton_value(value): - # print(f"[FLAGTREE] Creating hint annotation for tensor: {flagtree_hints}") - # Create hint annotation - hint_val = code_generator.builder.get_unit_attr() - code_generator.builder.create_annotation(value.handle, 'dot_pad_only_k', hint_val) - - -def visit_For_ext_support(): - import triton.language as language - return [language.parallel] - - -def set_bind_sub_block_when_parallel(IteratorClass, iterator, bind_sub_block): - import triton.language as language - if (IteratorClass is language.parallel): - return iterator.bind_sub_block - return bind_sub_block - - -def check_override_bind_sub_block(code_generator, node, bind_sub_block): - # flagtree: After normal processing, check if we need to override bind_sub_block - if hasattr(node, 'lineno') and hasattr(code_generator, 'jit_fn'): - line_num = node.lineno - # TODO: reparse needed in case we need to deal with complex cases, will be redesigned later - function_def = code_generator.jit_fn.parse() - line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) - flagtree_hints = line_flagtree_hints.get(line_num) - - # Check if this is a range/for loop with bind_sub_block hint - if flagtree_hints and 'bind_sub_block' in flagtree_hints: +import ast +import inspect +import re +import sys +import warnings +import os +import textwrap + +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union + +import triton.language.extra.cann.extension as extension +from triton.extension.buffer.language import core as bl +from triton.extension.buffer.language.builder import setup_unified_builder_with_buffer_builder + +from .. 
import language +from .._C.libtriton import ir, buffer_ir +from .._C.libtriton.ascend import ir as ascend_ir +from ..language import constexpr, tensor, str_to_ty +from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value +from ..runtime.jit import _normalize_ty, get_jit_fn_file_line +# ideally we wouldn't need any runtime component +from ..runtime import JITFunction +from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) +from types import ModuleType +# Central registry for all 'with' statement handlers +WITH_DISPATCH = {} + +# Import and register Ascend extension dispatch handlers +from triton.language.extra.cann.extension.dispatch import ASCEND_WITH_DISPATCH +from triton.language.extra.cann.extension.builder import setup_unified_builder + +WITH_DISPATCH.update(ASCEND_WITH_DISPATCH) + + +def mangle_ty(ty): + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + SIGNED = language.dtype.SIGNEDNESS.SIGNED + prefix = 'i' if ty.int_signedness == SIGNED else 'u' + return prefix + str(ty.int_bitwidth) + if ty.is_floating(): + return str(ty) + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) + return f'{elt}S{shape}S' + if ty.is_void(): + return 'V' + raise TypeError(f'Unsupported type {ty}') + + +mangle_ty = WITH_DISPATCH.get("mangle_ty", mangle_ty) + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + +def _is_triton_value(o: Any) -> bool: + return isinstance(o, _value) + + +def _is_triton_tensor(o: Any) -> bool: + return isinstance(o, tensor) + + +def _is_constexpr(o: Any) -> bool: + return isinstance(o, constexpr) + + +def _is_triton_scalar(o: Any) -> bool: + return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) + + +def _is_list_like(o: Any) -> bool: + return isinstance(o, (list, tuple)) + + +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and not _is_triton_scalar(arg): + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) + + +_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels + + +class enter_sub_region: + + def __init__(self, generator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + self.insert_point = self.generator.builder.get_insertion_point() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.restore_insertion_point(self.insert_point) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + + +# Check if the given syntax node has an "early" return 
+class ContainsReturnChecker(ast.NodeVisitor): + + def __init__(self, gscope): + self.gscope = gscope + + def _visit_stmts(self, body) -> bool: + return any(self.visit(s) for s in body) + + def _visit_function(self, fn) -> bool: + # Currently we only support JITFunctions defined in the global scope + if isinstance(fn, JITFunction) and not fn.noinline: + fn_node = fn.parse() + return ContainsReturnChecker(self.gscope).visit(fn_node) + return False + + def generic_visit(self, node) -> bool: + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.visit(item) + elif isinstance(value, ast.AST): + ret = ret or self.visit(value) + return ret + + def visit_Attribute(self, node: ast.Attribute) -> bool: + # If the left part is a name, it's possible that + # we call triton native function or a jit function from another module. + # If the left part is not a name, it must return a tensor or a constexpr + # whose methods do not contain return statements + # e.g., (tl.load(x)).to(y) + # So we only check if the expressions within value have return or not + if isinstance(node.value, ast.Name): + if node.value.id in self.gscope: + value = self.gscope[node.value.id] + fn = getattr(value, node.attr) + return self._visit_function(fn) + return False + return self.visit(node.value) + + def visit_Name(self, node: ast.Name) -> bool: + if type(node.ctx) is ast.Store: + return False + if node.id in self.gscope: + fn = self.gscope[node.id] + return self._visit_function(fn) + return False + + def visit_Return(self, node: ast.Return) -> bool: + return True + + def visit_Assign(self, node: ast.Assign) -> bool: + # There couldn't be an early return + # x = ... + return False + + def visit_AugAssign(self, node: ast.AugAssign) -> bool: + # There couldn't be an early return + # x += ... 
+ return False + + def visit_Module(self, node: ast.Module) -> bool: + return self._visit_stmts(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: + return self._visit_stmts(node.body) + + def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return + ret = self._visit_stmts(node.body) + if node.orelse: + ret = ret or self._visit_stmts(node.orelse) + return ret + + def visit_IfExp(self, node: ast.IfExp) -> bool: + return self.visit(node.body) or self.visit(node.orelse) + + def visit_Call(self, node: ast.Call) -> bool: + return self.visit(node.func) + + +class CodeGenerator(ast.NodeVisitor): + + def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options, + codegen_fns, module_map, module=None, is_kernel=False, function_types: Optional[Dict] = None, + noinline=False, file_name: Optional[str] = None, begin_line=0): + self.context = context + # Only NPUOptions has force_simt_only attribute, so check for NPU backend + if hasattr(options, "force_simt_only") and options.force_simt_only: + self.builder = ir.builder(context, compile_mode="simt") + else: + self.builder = ir.builder(context, compile_mode="simd") + self.file_name = file_name + # node.lineno starts from 1, so we need to subtract 1 + self.begin_line = begin_line - 1 + self.builder.set_loc(file_name, begin_line, 0) + self.builder.options = options + + # Set up unified builder interface with methods from specialized builders + self.ascend_builder = ascend_ir.ascendnpu_ir_builder(context, getattr(options, "arch", "")) + self.ascend_builder.set_loc(file_name, begin_line, 0) + setup_unified_builder(self.builder, self.ascend_builder) + self.buffer_builder = buffer_ir.buffer_builder(context) + self.buffer_builder.set_loc(file_name, begin_line, 0) + setup_unified_builder_with_buffer_builder(self.builder, self.buffer_builder) + + # dict of functions provided by the backend. Below are the list of possible functions: + # Convert custom types not natively supported on HW. + # convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None) + self.builder.codegen_fns = codegen_fns + self.builder.module_map = {} if module_map is None else module_map + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = {} if function_types is None else function_types + self.prototype = prototype + + self.gscope = {} + for k, v in gscope.items(): + if isinstance(v, ModuleType): + self.gscope[k] = module_map.get(v.__name__, v) + continue + + module_name = getattr(v, "__module__", "") + if module_name in module_map: + self.gscope[k] = getattr(module_map[module_name], v.__name__) + else: + self.gscope[k] = v + + self.lscope = {} + self.attributes = attributes + self.constants = constants + self.jit_fn = jit_fn + self.function_name = function_name + self.is_kernel = is_kernel + self.cur_node = None + self.noinline = noinline + self.scf_stack = [] + self.ret_type = None + # SSA-construction + # name => language.tensor + self.local_defs: Dict[str, tensor] = {} + self.dereference_name: Callable[[str], Any] = self._define_name_lookup() + self.fn = None + # Are we currently visiting an ast.arg's default value? These have some + # special handling. 
+ self.visiting_arg_default_value = False + + builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)} + builtin_namespace.update(( + ('print', language.core.device_print), + ('min', language.minimum), + ('max', language.maximum), + )) + + def _unsupported(self, node, message): + return UnsupportedLanguageConstruct(self.jit_fn.src, node, message) + + def _is_constexpr_global(self, name): + absent_marker = object() + val = self.gscope.get(name, absent_marker) + if val is absent_marker: + return False + + if _is_constexpr(val): return True - # print(f"[FLAGTREE] Found bind_sub_block hint at line {line_num}") - return bind_sub_block + if a := self.gscope.get("__annotations__", {}).get(name): + return _normalize_ty(a) == "constexpr" + + return False + + def _define_name_lookup(self): + + def local_lookup(name: str, absent): + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + return self.lscope.get(name, absent) + + def global_lookup(name: str, absent): + val = self.gscope.get(name, absent) + # The high-level rule is that only constexpr globals are allowed. + # But actually a bunch of other things, such as module imports, are + # technically Python globals. We have to allow these too! + if any([ + val is absent, name in self.builtin_namespace, # + type(val) is ModuleType, # + isinstance(val, JITFunction), # + getattr(val, "__triton_builtin__", False), # + getattr(val, "__module__", "").startswith("triton.language"), # + isinstance(val, language.dtype), # + self._is_constexpr_global(name), # + # Allow accesses to globals while visiting an ast.arg + # because you should be able to do + # @triton.jit def fn(x: tl.constexpr = GLOBAL): ... + self.visiting_arg_default_value, # + os.environ.get("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS", "0") == "1" + ]): + return val + raise NameError( + textwrap.dedent(f"""\ + Cannot access global variable {name} from within @jit'ed + function. Triton kernels can only access global variables that + are annotated as constexpr (`x: triton.language.constexpr = 42` + or `x = triton.language.constexpr(42)`). Alternatively, set the + envvar TRITON_ALLOW_NON_CONSTEXPR_GLOBALS=1, but we do not + promise to support this forever.""").replace("\n", " ")) + + absent_marker = object() + + def name_lookup(name: str) -> Any: + absent = absent_marker + for lookup_function in local_lookup, global_lookup, self.builtin_namespace.get: + value = lookup_function(name, absent) + if value is not absent: + return value + raise NameError(f'{name} is not defined') + + return name_lookup + + def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def _get_insertion_point_and_loc(self, builder=None): + # XXX: this is a hack to get the location of the insertion point. 
+ # The insertion point's location could be invalid sometimes, + # so we need to explicitly set the location + _builder = self.builder if not builder else builder + loc = _builder.get_loc() + ip = _builder.get_insertion_point() + return ip, loc + + def _set_insertion_point_and_loc(self, ip, loc, builder=None): + _builder = self.builder if not builder else builder + _builder.restore_insertion_point(ip) + _builder.set_loc(loc) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + # Ensure that stmts is iterable + if not _is_list_like(stmts): + stmts = [stmts] + for stmt in stmts: + self.visit(stmt) + + # Stop parsing as soon as we hit a `return` statement; everything + # after this is dead code. + if isinstance(stmt, ast.Return): + break + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = [self.visit(elt) for elt in node.elts] + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + if ret_value is None: + self.builder.ret([]) + ret_ty = language.void + elif isinstance(ret_value, tuple): + ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value] + ret_types = [v.type for v in ret_values] + self.builder.ret([v.handle for v in ret_values]) + ret_ty = tuple(ret_types) + else: + ret = language.semantic.to_tensor(ret_value, self.builder) + self.builder.ret([ret.handle]) + ret_ty = ret.type + + if self.ret_type is None: + self.ret_type = ret_ty + elif self.ret_type != ret_ty: + raise TypeError(f'Inconsistent return types: {self.ret_type} and {ret_ty}') + + # A return op must always terminate the basic block, so we create a dead + # basic block in case there are any ops after the return. 
+ post_ret_block = self.builder.create_block() + self.builder.set_insertion_point_to_end(post_ret_block) + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + if self.fn: + raise self._unsupported(node, "nested function definition is not supported.") + # initialize defaults + for i, default_value in enumerate(node.args.defaults[::-1]): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + self.visit(init_node) + finally: + self.visiting_arg_default_value = False + + # initialize function + visibility = "public" if self.is_kernel else "private" + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, + self.prototype.to_ir(self.builder), visibility, self.noinline) + self.module.push_back(self.fn) + entry = self.fn.add_entry_block() + arg_values = [] + idx = 0 + for i in range(len(arg_names)): + if i in self.constants: + cst = self.constants[i] + if not _is_constexpr(cst): + cst = constexpr(self.constants[i]) + arg_values.append(cst) + continue + else: + if i in self.attributes: + for name, value in self.attributes[i]: + self.fn.set_arg_attr(idx, name, value) + + # Mark this argument as a pass-by-value TMA descriptor (nvidia) + if isinstance(self.prototype.param_types[idx], nv_tma_desc_type): + self.fn.set_arg_attr(idx, "tt.nv_tma_desc", 1) + + param_type = self.prototype.param_types[idx] + if isinstance(param_type, bl.buffer_type): + arg_values.append(bl.buffer(self.fn.args(idx), param_type)) + else: + arg_values.append(tensor(self.fn.args(idx), param_type)) + idx += 1 + + insert_pt = self.builder.get_insertion_block() + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + self.builder.set_insertion_point_to_start(entry) + # visit function body + self.visit_compound_statement(node.body) + + # finalize function + assert not self.builder.get_insertion_block().has_terminator() + if self.ret_type is None or self.ret_type == language.void: + self.ret_type = language.void + self.builder.ret([]) + else: + self.prototype.ret_types = list(self.ret_type) if isinstance(self.ret_type, tuple) else [self.ret_type] + self.fn.reset_type(self.prototype.to_ir(self.builder)) + self.builder.ret([ + self.builder.create_poison(ty.to_ir(self.builder)) + for ty in self.prototype.ret_types + if self.ret_type is not None + ]) + self.fn.finalize() + + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' 
+ f' constexpr cannot be reassigned.') + if not _is_constexpr(value): + value = constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def visit_Assign(self, node): + _names = [] + if isinstance(node, ast.AnnAssign): + _names += [self.visit(node.target)] + else: + for target in node.targets: + _names += [self.visit(target)] + if len(_names) > 1: + raise self._unsupported(node, "simultaneous multiple assignment is not supported.") + names = _names[0] + values = self.visit(node.value) + if not _is_list_like(names): + names = [names] + if not _is_list_like(values): + values = [values] + native_nontensor_types = (language.dtype, ) + for name, value in zip(names, values): + # by default, constexpr are assigned into python variable + value = _unwrap_if_constexpr(value) + if value is not None and \ + not _is_triton_value(value) and \ + not isinstance(value, native_nontensor_types): + value = language.semantic.to_tensor(value, self.builder) + self.set_value(name, value) + + def visit_AugAssign(self, node): + name = node.target.id + lhs = ast.Name(id=name, ctx=ast.Load()) + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.dereference_name(name) + + def visit_Name(self, node): + if type(node.ctx) is ast.Store: + return node.id + return self.dereference_name(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return tuple(args) + + def _apply_binary_method(self, method_name, lhs, rhs): + # TODO: raise something meaningful if getattr fails below, esp for reverse method + if _is_triton_tensor(lhs): + return getattr(lhs, method_name)(rhs, _builder=self.builder) + if _is_triton_tensor(rhs): + reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) + return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder) + return getattr(lhs, method_name)(rhs) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + method_name = self._method_name_for_bin_op.get(type(node.op)) + if method_name is None: + raise self._unsupported(node, + "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', + } + + def visit_then_else_blocks(self, node, liveins, then_block, else_block): + # then block + self.builder.set_insertion_point_to_start(then_block) + self.visit_compound_statement(node.body) + then_block = self.builder.get_insertion_block() + then_defs = self.local_defs.copy() + # else block + else_defs = {} + if node.orelse: + self.builder.set_insertion_point_to_start(else_block) + self.lscope = liveins.copy() + self.local_defs = {} + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else_block = self.builder.get_insertion_block() + + # update block arguments + names = [] + ret_types = [] + ir_ret_types = [] + # variables in livein whose 
value is updated in `if` + for name in liveins: + # check type + for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: + if name in defs: + assert defs[name].type == liveins[name].type, \ + f'initial value for `{name}` is of type {liveins[name].type}, '\ + f'but the {block_name} block redefines it as {defs[name].type}' + if name in then_defs or name in else_defs: + names.append(name) + ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) + ir_ret_types.append(then_defs[name].handle.get_type() if name in + then_defs else else_defs[name].handle.get_type()) + # variable defined in then but not in else + if name in then_defs and name not in else_defs: + else_defs[name] = liveins[name] + # variable defined in else but not in then + if name in else_defs and name not in then_defs: + then_defs[name] = liveins[name] + # variables that are both in then and else but not in liveins + # TODO: could probably be cleaned up + for name in sorted(then_defs.keys() & else_defs.keys()): + if name in names: + continue + then_ty = then_defs[name].type + else_ty = else_defs[name].type + assert then_ty == else_ty, \ + f'Mismatched type for {name} between then block ({then_ty}) '\ + f'and else block ({else_ty})' + names.append(name) + ret_types.append(then_ty) + ir_ret_types.append(then_defs[name].handle.get_type()) + + return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types + + def visit_if_top_level(self, cond, node): + with enter_sub_region(self) as sr: + liveins, ip_block = sr + then_block = self.builder.create_block() + else_block = self.builder.create_block() + # create branch + self.builder.set_insertion_point_to_end(ip_block) + self.builder.create_cond_branch(cond.handle, then_block, else_block) + # visit then and else blocks + then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create basic-block after conditional + endif_block = self.builder.create_block() + # then terminator + self.builder.set_insertion_point_to_end(then_block) + assert not then_block.has_terminator(), f"{then_block}" + self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) + # else terminator + self.builder.set_insertion_point_to_end(else_block) + assert not else_block.has_terminator(), f"{else_block}" + self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) + for ty in ir_ret_types: + endif_block.add_argument(ty) + + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + for i, name in enumerate(names): + new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) + self.set_value(name, new_tensor) + + # TODO: refactor + def visit_if_scf(self, cond, node): + with enter_sub_region(self) as sr: + liveins, _ = sr + ip, last_loc = self._get_insertion_point_and_loc() + then_block = self.builder.create_block() + else_block = self.builder.create_block() if node.orelse else None + then_defs, else_defs, then_block, else_block, names, ret_types, _ = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create if op + self._set_insertion_point_and_loc(ip, last_loc) + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + if len(names) > 0: + self.builder.create_yield_op([then_defs[n].handle for n in 
names]) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + if len(names) > 0: + self.builder.create_yield_op([else_defs[n].handle for n in names]) + # update values + for i, name in enumerate(names): + new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i]) + self.set_value(name, new_tensor) + + def visit_If(self, node): + cond = self.visit(node.test) + + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + contains_return = ContainsReturnChecker(self.gscope).visit(node) + if contains_return: + if self.scf_stack: + raise self._unsupported( + node, "Cannot have `return` statements inside `while` or `for` statements in triton " + "(note that this also applies to `return` statements that are inside functions " + "transitively called from within `while`/`for` statements)") + self.visit_if_top_level(cond, node) + else: + self.visit_if_scf(cond, node) + else: + cond = _unwrap_if_constexpr(cond) + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) + + active_block = node.body if cond else node.orelse + self.visit_compound_statement(active_block) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + # TODO: Deal w/ more complicated return types (e.g tuple) + with enter_sub_region(self): + ip, last_loc = self._get_insertion_point_and_loc() + + then_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(then_block) + then_val = language.semantic.to_tensor(self.visit(node.body), self.builder) + then_block = self.builder.get_insertion_block() + + else_block = self.builder.create_block() + self.builder.set_insertion_point_to_start(else_block) + # do not need to reset lscope since + # ternary expressions cannot define new variables + else_val = language.semantic.to_tensor(self.visit(node.orelse), self.builder) + else_block = self.builder.get_insertion_block() + + self._set_insertion_point_and_loc(ip, last_loc) + + assert then_val.type == else_val.type, \ + f'Ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}' + ret_type = then_val.type + + ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else [] + if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + self.builder.create_yield_op([then_val.handle]) + + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + else_block.merge_block_before(if_op.get_else_block()) + if ret_type_ir: + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + self.builder.create_yield_op([else_val.handle]) + return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None + else: + cond = _unwrap_if_constexpr(cond) + + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: + raise self._unsupported( + node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', 
'.join(_.__name__ for _ in _condition_types),
+                        type(cond).__name__))
+            if cond:
+                return self.visit(node.body)
+            else:
+                return self.visit(node.orelse)
+
+    def visit_Pass(self, node):
+        pass
+
+    def visit_With(self, node):
+        """Handle 'with' statements using the dispatch pattern."""
+        assert len(node.items) == 1
+        context = node.items[0].context_expr
+
+        # Check if context is a Call and dispatch to registered handler
+        if isinstance(context, ast.Call):
+            withitemClass = self.visit(context.func)
+            handler = WITH_DISPATCH.get(withitemClass)
+            if handler:
+                return handler(self, node)
+
+        # Fall back to visiting body for unhandled cases
+        return self.visit_compound_statement(node.body)
+
+    def visit_Compare(self, node):
+        if not (len(node.comparators) == 1 and len(node.ops) == 1):
+            raise self._unsupported(node, "simultaneous multiple comparison is not supported")
+        lhs = self.visit(node.left)
+        rhs = self.visit(node.comparators[0])
+        lhs_value = _unwrap_if_constexpr(lhs)
+        rhs_value = _unwrap_if_constexpr(rhs)
+        if type(node.ops[0]) is ast.Is:
+            return constexpr(lhs_value is rhs_value)
+        if type(node.ops[0]) is ast.IsNot:
+            return constexpr(lhs_value is not rhs_value)
+        method_name = self._method_name_for_comp_op.get(type(node.ops[0]))
+        if method_name is None:
+            raise self._unsupported(
+                node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__))
+        return self._apply_binary_method(method_name, lhs, rhs)
+
+    _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = {
+        ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__'
+    }
+
+    def visit_UnaryOp(self, node):
+        operand = self.visit(node.operand)
+        fn = self._method_name_for_unary_op.get(type(node.op))
+        if fn is None:
+            raise self._unsupported(node, f"AST unary operator '{node.op.__name__}' is not (currently) implemented.")
+        if _is_triton_tensor(operand):
+            return getattr(operand, fn)(_builder=self.builder)
+        try:
+            return getattr(operand, fn)()
+        except AttributeError:
+            raise self._unsupported(
+                node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")
+
+    _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {
+        ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
+    }
+
+    def _verify_loop_carried_variable(self, name, loop_val, live_val):
+        assert _is_triton_value(loop_val), f'cannot reassign constexpr {name} in the loop'
+        assert _is_triton_value(live_val), f'cannot reassign constexpr {name} in the loop'
+        assert type(loop_val) == type(live_val), f'Loop-carried variable {name} changed type'
+        assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
+            f'Loop-carried variable {name} has initial type {live_val.type} '\
+            f'but is re-assigned to {loop_val.type} in loop! '\
+            f'Please make sure that the type stays consistent.'
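To make the dispatch used by visit_With above concrete, a hypothetical registration sketch (MyScope and handle_my_scope are invented for illustration; the real handlers come from ASCEND_WITH_DISPATCH in the cann extension):

    # illustrative only -- not part of this change
    class MyScope:
        """Hypothetical context object used as `with MyScope():` inside a kernel."""
        pass

    def handle_my_scope(generator, node):
        # A real handler would emit region-entering IR via generator.builder here,
        # then lower the statement body much like the fallback path does.
        return generator.visit_compound_statement(node.body)

    # visit_With evaluates `node.items[0].context_expr.func`, looks the result up in
    # WITH_DISPATCH, and calls the registered handler with (generator, node).
    WITH_DISPATCH[MyScope] = handle_my_scope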
+ + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # loop body (the after region) + # loop_block = self.builder.create_block() + dummy = self.builder.create_block() + self.builder.set_insertion_point_to_start(dummy) + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + dummy.erase() + + # collect loop-carried values + names = [] + ret_types = [] + init_args = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr + loop_val = loop_defs[name] + live_val = liveins[name] + self._verify_loop_carried_variable(name, loop_val, live_val) + + # these are loop-carried values + names.append(name) + ret_types.append(loop_val.type) + init_args.append(live_val) + + self._set_insertion_point_and_loc(ip, last_loc) + while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], + [arg.handle for arg in init_args]) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) + self.builder.set_insertion_point_to_start(before_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + cond = self.visit(node.test) + self.builder.set_insertion_point_to_end(before_block) + # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... + self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), + [ty.to_ir(self.builder) for ty in ret_types]) + + # generate loop body + self.builder.set_insertion_point_to_start(after_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + yields = [] + for name in loop_defs: + if name in liveins: + yields.append(loop_defs[name]) + self.builder.create_yield_op([y.handle for y in yields]) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript(self, node): + assert node.ctx.__class__.__name__ == "Load" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if _is_triton_tensor(lhs): + return lhs.__getitem__(slices, _builder=self.builder) + return lhs[slices] + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) + iter_args = [self.visit(arg) for arg in node.iter.args] + iter_kwargs = dict(self.visit(keyword) for keyword in node.iter.keywords) + if IteratorClass == language.static_range: + iterator = IteratorClass(*iter_args, **iter_kwargs) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) + for i in static_range: + self.lscope[node.target.id] = constexpr(i) + self.visit_compound_statement(node.body) + for stmt in 
node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + num_stages = None + loop_unroll_factor = None + disallow_acc_multi_buffer = False + flatten = False + warp_specialize = False + disable_licm = False + bind_sub_block = None + if IteratorClass in [language.range, extension.parallel]: + iterator = IteratorClass(*iter_args, **iter_kwargs) + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iterator.start + ub = iterator.end + step = iterator.step + num_stages = iterator.num_stages + loop_unroll_factor = iterator.loop_unroll_factor + disallow_acc_multi_buffer = iterator.disallow_acc_multi_buffer + flatten = iterator.flatten + warp_specialize = iterator.warp_specialize + disable_licm = iterator.disable_licm + if (IteratorClass is extension.parallel): + bind_sub_block = iterator.bind_sub_block + elif IteratorClass is range: + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) + else: + raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + # handle negative constant step (not supported by scf.for in MLIR) + negative_step = False + if _is_constexpr(step) and step.value < 0: + step = constexpr(-step.value) + negative_step = True + lb, ub = ub, lb + lb = language.semantic.to_tensor(lb, self.builder) + ub = language.semantic.to_tensor(ub, self.builder) + step = language.semantic.to_tensor(step, self.builder) + # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") + iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype) + iv_ir_type = iv_type.to_ir(self.builder) + iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = lb.handle + ub = ub.handle + step = step.handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) + ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) + step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) + # Create placeholder for the loop induction variable + iv = self.builder.create_poison(iv_ir_type) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + # dry visit loop body + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + block.erase() + + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. 
(They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + loop_val = self.local_defs[name] + live_val = liveins[name] + self._verify_loop_carried_variable(name, loop_val, live_val) + + names.append(name) + init_args.append(live_val) + yields.append(loop_val) + + # create ForOp + self._set_insertion_point_and_loc(ip, last_loc) + for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + if num_stages is not None: + for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + if loop_unroll_factor is not None: + for_op.set_attr("tt.loop_unroll_factor", self.builder.get_int32_attr(loop_unroll_factor)) + if disallow_acc_multi_buffer: + for_op.set_attr("tt.disallow_acc_multi_buffer", self.builder.get_unit_attr()) + if flatten: + for_op.set_attr("tt.flatten", self.builder.get_unit_attr()) + if warp_specialize: + for_op.set_attr("tt.warp_specialize", self.builder.get_unit_attr()) + if disable_licm: + for_op.set_attr("tt.disable_licm", self.builder.get_unit_attr()) + if (IteratorClass is extension.parallel): + for_op.set_attr("hivm.parallel_loop", self.builder.get_unit_attr()) + + self.scf_stack.append(node) + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + # reset local scope to not pick up local defs from the previous dry run. + self.lscope = liveins.copy() + self.local_defs = {} + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type)) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yields = [] + for name in self.local_defs: + if name in liveins: + yields.append(language.semantic.to_tensor(self.local_defs[name], self.builder)) + + # create YieldOp + if len(yields) > 0: + self.builder.create_yield_op([y.handle for y in yields]) + for_op_region = for_op.get_body(0).get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + iv = for_op.get_induction_var() + if negative_step: + iv = self.builder.create_sub(ub, iv) + iv = self.builder.create_add(iv, lb) + self.lscope[node.target.id].handle.replace_all_uses_with(iv) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + # update lscope & local_defs (ForOp defines new values) + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type)) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node) -> Tuple[str, Any]: + return node.arg, self.visit(node.value) + + def visit_Assert(self, node) -> Any: + test = self.visit(node.test) + msg = self.visit(node.msg) if node.msg is not None else "" + return language.core.device_assert(test, msg, _builder=self.builder) + + def call_JitFunction(self, fn: JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args] + # generate function def + attributes = {} + constexprs = 
[i for i, arg in enumerate(args) if _is_constexpr(arg)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + prototype = language.function_type([], arg_types) + gscope = fn.__globals__ + # If the callee is not set, we use the same debug setting as the caller + file_name, begin_line = get_jit_fn_file_line(fn) + generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, + jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types, + noinline=fn.noinline, file_name=file_name, begin_line=begin_line, + options=self.builder.options, codegen_fns=self.builder.codegen_fns, + module_map=self.builder.module_map) + try: + generator.visit(fn.parse()) + except Exception as e: + # Wrap the error in the callee with the location of the call. + raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from e + + callee_ret_type = generator.ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0 or callee_ret_type is None: + return None + elif call_op.get_num_results() == 1: + return tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + results = [] + for i in range(call_op.get_num_results()): + results.append(tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) + + def visit_Call(self, node): + fn = _unwrap_if_constexpr(self.visit(node.func)) + static_implementation = self.statically_implemented_functions.get(fn) + if static_implementation is not None: + return static_implementation(self, node) + + kws = dict(self.visit(keyword) for keyword in node.keywords) + args = [self.visit(arg) for arg in node.args] + if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) + return self.call_JitFunction(fn, args, kws) + if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn): + # Copy builder's location and insertion point. + ip, last_loc = self._get_insertion_point_and_loc() + # Use ascend_builder if this function is a builtin extension operation. + _builder = self.ascend_builder if extension.is_builtin(fn) else self.builder + self._set_insertion_point_and_loc(ip, last_loc, _builder) + extra_kwargs = {"_builder": _builder} + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + try: + ret = fn(*args, **extra_kwargs, **kws) + # Sync the builder's location before return. + ip, last_loc = self._get_insertion_point_and_loc(_builder) + self._set_insertion_point_and_loc(ip, last_loc) + return ret + except Exception as e: + # Normally when we raise a CompilationError, we raise it as + # `from None`, because the original fileline from the exception + # is not relevant (and often points into code_generator.py + # itself). But when calling a function, we raise as `from e` to + # preserve the traceback of the original error, which may e.g. + # be in core.py. 
+ raise CompilationError(self.jit_fn.src, node, repr(e)) from e + + if fn in self.builtin_namespace.values(): + args = map(_unwrap_if_constexpr, args) + return fn(*args, **kws) + + def visit_Constant(self, node): + return constexpr(node.value) + + def visit_BoolOp(self, node: ast.BoolOp): + if len(node.values) != 2: + raise self._unsupported( + node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") + lhs = self.visit(node.values[0]) + rhs = self.visit(node.values[1]) + method_name = self._method_name_for_bool_op.get(type(node.op)) + if method_name is None: + raise self._unsupported( + node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} + + if sys.version_info < (3, 8): + + def visit_NameConstant(self, node): + return constexpr(node.value) + + def visit_Num(self, node): + return constexpr(node.n) + + def visit_Str(self, node): + return constexpr(ast.literal_eval(node)) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + if _is_triton_tensor(lhs) and node.attr == "T": + return language.semantic.permute(lhs, (1, 0), builder=self.builder) + return getattr(lhs, node.attr) + + def visit_Expr(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit_JoinedStr(self, node): + values = list(node.values) + for i, value in enumerate(values): + if isinstance(value, ast.Constant): + values[i] = str(value.value) + elif isinstance(value, ast.FormattedValue): + conversion_code = value.conversion + evaluated = self.visit(value.value) + if not _is_constexpr(evaluated): + raise self._unsupported( + node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) + values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) + else: + raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) + return ''.join(values) + + def visit(self, node): + if node is None: + return + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 + last_node = self.cur_node + last_loc = self.builder.get_loc() + self.cur_node = node + if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): + self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) + last_loc = self.builder.get_loc() + try: + ret = super().visit(node) + except CompilationError: + raise + except Exception as e: + # Wrap the error in a CompilationError which contains the source + # of the @jit function. 
+ raise CompilationError(self.jit_fn.src, self.cur_node, repr(e)) from None + + # Reset the location to the last one before the visit + if last_loc: + self.cur_node = last_node + self.builder.set_loc(last_loc) + return ret + + def generic_visit(self, node): + raise self._unsupported(node, "unsupported AST node type: {}".format(type(node).__name__)) + + def execute_static_assert(self, node: ast.Call) -> None: + arg_count = len(node.args) + if not (0 < arg_count <= 2) or len(node.keywords): + raise TypeError("`static_assert` requires one or two positional arguments only") + + passed = _unwrap_if_constexpr(self.visit(node.args[0])) + if not isinstance(passed, bool): + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values" + ) + if not passed: + if arg_count == 1: + message = "" + else: + try: + message = self.visit(node.args[1]) + except Exception as e: + message = "" + + raise CompileTimeAssertionFailure(self.jit_fn.src, node, _unwrap_if_constexpr(message)) + return None + + def static_executor(python_fn): + + def ret(self, node: ast.Call): + kws = { + name: _unwrap_if_constexpr(value) + for name, value in (self.visit(keyword) for keyword in node.keywords) + } + args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] + return constexpr(python_fn(*args, **kws)) + + return ret + + statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { + language.core.static_assert: execute_static_assert, + language.core.static_print: static_executor(print), + int: static_executor(int), + len: static_executor(len), + extension.int64: static_executor(extension.int64), + } + + +def kernel_suffix(signature, specialization): + # suffix format: + # <'c' if equal to 1><'d' if divisible by 16><'e' if divisible by 8> + suffix = '' + for i, _ in enumerate(signature): + suffix += str(i) + if i in specialization.equal_to_1: + suffix += 'c' + if i in specialization.divisibility_16: + suffix += 'd' + return suffix + + +def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map, module=None): + attrs = specialization.attrs + # create kernel prototype + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in specialization.constants.items()} + # visit kernel AST + gscope = fn.__globals__.copy() + function_name = fn.repr(specialization) + tys = list(specialization.signature.values()) + new_constants = attrs.get_constants() + for k in new_constants: + if k in tys and tys[k] == "i1" and new_constants[k] == 1: + new_constants[k] = True -def forop_setattr_for_bind_sub_block(code_generator, for_op, bind_sub_block): - for_op.set_attr("bind_sub_block", code_generator.builder.get_bool_attr(bind_sub_block)) + new_attrs = attrs.filter_out_constants() + fn_attrs = new_attrs.get_fn_attrs() + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in specialization.signature.items() if k not in specialization.constants] + file_name, begin_line = get_jit_fn_file_line(fn) + prototype = language.function_type([], arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, + jit_fn=fn, attributes=fn_attrs, is_kernel=True, file_name=file_name, + begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map, + module=module) + generator.visit(fn.parse()) -def 
need_repr_in_CodeGenerator_CompilationError(): - return True + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret diff --git a/third_party/ascend/backend/spec/triton/compiler/compiler.py b/third_party/ascend/backend/spec/triton/compiler/compiler.py index c33274e9b..33c344aa9 100644 --- a/third_party/ascend/backend/spec/triton/compiler/compiler.py +++ b/third_party/ascend/backend/spec/triton/compiler/compiler.py @@ -1,37 +1,441 @@ -def ext_ASTSource_attrs(ast_source): - from triton.backends.ascend.compiler import AscendAttrsDescriptor - if ast_source.attrs is None: - ast_source.attrs = AscendAttrsDescriptor() - - -def opt_ascend_compile_speed(file_name, metadata_path, fn_cache_manager): - import os - compile_speed_opt = os.getenv("TRITON_ASCEND_COMPILE_SPEED_OPT", 'false').lower() in ('true', '1') - if (compile_speed_opt): - ttir_path = f"{file_name}.ttir" - if (metadata_path is None) and (fn_cache_manager.has_file(ttir_path)): - # Already compile once but failed. So directly return - raise Exception("already failed once") - - -def set_CompiledKernel_metadata_stream(compiled_kernel, stream): - if stream is None: - return stream - return compiled_kernel.metadata.stream - - -def handle_compile_error(e, ext, fn_cache_manager): - from .errors import MLIRCompilationError - if (ext == "ttadapter"): - stage_name = "ConvertTritonIRToLinalgIR" - elif (ext == "npubin"): - stage_name = "ConvertLinalgRToBinary" +from __future__ import annotations +import hashlib +import json +from .._C.libtriton import get_cache_invalidating_env_vars, ir, buffer_ir +from .._C.libtriton.ascend import ir as ascend_ir +from ..backends import backends +from ..backends.compiler import GPUTarget, AttrsDescriptor +from .. import __version__ +from ..runtime.autotuner import OutOfResources +from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager +from ..runtime.driver import driver +from ..tools.disasm import get_sass +# TODO: this shouldn't be here +from .code_generator import ast_to_ttir +from .errors import MLIRCompilationError +from pathlib import Path +import re +import functools +import os + +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3 +mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ttir": mlir_prototype_pattern, + "ttgir": mlir_prototype_pattern, + "ptx": ptx_prototype_pattern, +} + +mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?' 
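+# Illustrative example (assumed TTIR text, not taken from this repository): for a prototype line like
+#   tt.func public @add_kernel(%arg0: !tt.ptr<f32>, %arg1: i32) attributes {noinline = false} {
+# mlir_prototype_pattern captures "@add_kernel" as group 1 and the parenthesized argument list as
+# group 2; mlir_arg_type_pattern is then run over that argument list and yields the per-argument
+# type strings ("!tt.ptr<f32>", "i32") that IRSource below uses to build its signature.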
+ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ttir": mlir_arg_type_pattern, + "ttgir": mlir_arg_type_pattern, + "ptx": ptx_arg_type_pattern, +} + + +def convert_type_repr(x): + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) + tma = re.search(r'tt.nv_tma_desc = 1', x) + if tma is not None: + return 'nvTmaDesc' + x = re.sub(r' {[^}]+}', '', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +def _get_num_warps_from_ir_str(src: str): + ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' + # TODO(jlebar): Using a regex to get num-warps is a hack, and will break if + # e.g. someone has an instruction (not module) attribute named "num-warps". + num_warps_matches = re.findall(ttgir_num_warps_pattern, src) + assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" + num_warps = int(num_warps_matches[0]) + return num_warps + + +class ASTSource: + + def __init__(self, fn, signature, constants=None, attrs=None) -> None: + self.fn = fn + self.ext = "ttir" + self.name = fn.__name__ + self.signature = signature + self.constants = constants + self.attrs = attrs + if isinstance(self.signature, str): + self.signature = {k: v.strip() for k, v in enumerate(self.signature.split(","))} + else: + for k in self.signature.keys(): + if not isinstance(k, str): + raise TypeError("Signature keys must be string") + if self.constants is None: + self.constants = {} + else: + for k in self.constants.keys(): + if not isinstance(k, str): + raise TypeError("Constants keys must be string") + if self.attrs is None: + self.attrs = AttrsDescriptor() + + def hash(self): + sorted_sig = [v for k, v in sorted(self.signature.items())] + # Note - we stringify the keys here to allow sorting to work for cases + # where constants have mixed int/str keys. 
+ sorted_constants = sorted((str(k), v) for k, v in self.constants.items()) + key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, module_map, context): + return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns, + module_map=module_map) + + def parse_options(self): + return dict() + + +class IRSource: + + def __init__(self, path): + self.path = path + path = Path(path) + self.ext = path.suffix[1:] + self.src = path.read_text() + match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE) + self.name = match.group(1) + signature = match.group(2) + types = re.findall(arg_type_pattern[self.ext], signature) + self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)} + + def hash(self): + return hashlib.sha256(self.src.encode("utf-8")).hexdigest() + + def make_ir(self, options, codegen_fns, module_map, context): + module = ir.parse_mlir_module(self.path, context) + module.context = context + return module + + def parse_options(self): + if self.ext == "ttgir": + return {'num_warps': _get_num_warps_from_ir_str(self.src)} + return dict() + + +@functools.lru_cache() +def triton_key(): + import pkgutil + TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + # compiler + path_prefixes = [ + (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."), + (os.path.join(TRITON_PATH, "backends"), "triton.backends."), + ] + for path, prefix in path_prefixes: + for lib in pkgutil.walk_packages([path], prefix=prefix): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + + # backend + libtriton_hash = hashlib.sha256() + with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + libtriton_hash.update(chunk) + contents.append(libtriton_hash.hexdigest()) + # language + language_path = os.path.join(TRITON_PATH, 'language') + for lib in pkgutil.walk_packages([language_path], prefix="triton.language."): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.sha256(f.read()).hexdigest()] + return f'{__version__}' + '-'.join(contents) + + +def parse(full_name, ext, context): + if ext == "ttir" or ext == "ttgir": + module = ir.parse_mlir_module(full_name, context) + module.context = context + return module + if ext == "llir" or ext == "ptx": + return Path(full_name).read_text() + if ext == "cubin": + return Path(full_name).read_bytes() + + +def filter_traceback(e: BaseException): + """ + Removes code_generator.py and related files from tracebacks. + + These are uninteresting to the user -- "just show me *my* code!" + """ + if os.getenv("TRITON_FRONT_END_DEBUGGING", "0") == "1": + return + + if e.__cause__ is not None: + filter_traceback(e.__cause__) + if e.__context__ is not None: + filter_traceback(e.__context__) + + # If a user has a file that matches one of these, they're out of luck. 
+ BAD_FILES = [ + "/triton/compiler/code_generator.py", + "/ast.py", + ] + + tb = e.__traceback__ + frames = [] + while tb is not None: + if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)): + frames.append(tb) + tb = tb.tb_next + + for (cur_frame, next_frame) in zip(frames, frames[1:]): + cur_frame.tb_next = next_frame + + if not frames: + e.__traceback__ = None else: - stage_name = "MLIRCompile" - error_detail = e.stderr.decode('utf-8') if hasattr(e, 'stderr') and e.stderr else str(e) - error_detail += f"\n\n[INFO]: The compiled kernel cache is in {fn_cache_manager.cache_dir}\n\n" - raise MLIRCompilationError(stage_name, error_detail) + frames[-1].tb_next = None + e.__traceback__ = frames[0] + + +def compile(src, target=None, options=None): + if target is None: + target = driver.active.get_current_target() + assert isinstance(target, GPUTarget), "target must be of GPUTarget type" + backend = make_backend(target) + ir_source = not isinstance(src, ASTSource) + # create backend + if ir_source: + assert isinstance(src, str), "source must be either AST or a filepath" + src = IRSource(src) + extra_options = src.parse_options() + options = backend.parse_options(dict(options or dict(), **extra_options)) + # create cache manager + env_vars = get_cache_invalidating_env_vars() + key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{options.hash()}-{str(sorted(env_vars.items()))}" + hash = hashlib.sha256(key.encode("utf-8")).hexdigest() + fn_cache_manager = get_cache_manager(hash) + # For dumping/overriding only hash the source as we want it to be independent of triton + # core changes to make it easier to track kernels by hash. + enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" + enable_ir_dump = os.environ.get("TRITON_KERNEL_DUMP", "0") == "1" + fn_override_manager = get_override_manager(src.hash()) if enable_override else None + fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None + # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms. + # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}". + # A PID string can be 5-character long. A UUID string has typically 36 characters. Let's truncate + # the file name to 150 characters to be safe. + file_name = src.name[:150] + metadata_filename = f"{file_name}.json" + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} + metadata_path = metadata_group.get(metadata_filename) + always_compile = os.environ.get("TRITON_ALWAYS_COMPILE", "0") == "1" + if not always_compile and metadata_path is not None: + # cache hit! + metadata = json.loads(Path(metadata_path).read_text()) + return CompiledKernel(src, metadata_group, hash) + # initialize metadata + metadata = { + "hash": hash, + "target": target, + **options.__dict__, + **env_vars, + } + # run compilation pipeline and populate metadata + stages = dict() + backend.add_stages(stages, options) + first_stage = list(stages.keys()).index(src.ext) + # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests. 
+ if ir_source: + first_stage += 1 + context = ir.context() + ir.load_dialects(context) + buffer_ir.load_dialects(context) + ascend_ir.load_dialects(context) + backend.load_dialects(context) + codegen_fns = backend.get_codegen_implementation() + module_map = backend.get_module_map() + try: + module = src.make_ir(options, codegen_fns, module_map, context) + except Exception as e: + filter_traceback(e) + raise + use_ir_loc = os.environ.get("USE_IR_LOC", None) + for ext, compile_ir in list(stages.items())[first_stage:]: + try: + next_module = compile_ir(module, metadata) + except Exception as e: + if (ext == "ttadapter"): + stage_name = "ConvertTritonIRToLinalgIR" + elif (ext == "npubin"): + stage_name = "ConvertLinalgRToBinary" + else: + stage_name = "MLIRCompile" + error_detail = e.stderr.decode('utf-8') if hasattr(e, 'stderr') and e.stderr else str(e) + error_detail += f"\n\n[INFO]: The compiled kernel cache is in {fn_cache_manager.cache_dir}\n\n" + raise MLIRCompilationError(stage_name, error_detail) from e + ir_filename = f"{file_name}.{ext}" + if (fn_override_manager is not None and (full_name := fn_override_manager.get_file(ir_filename)) is not None): + print(f"\nOverriding kernel with file {full_name}") + next_module = parse(full_name, ext, context) + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + if fn_dump_manager is not None: + fn_dump_manager.put(next_module, ir_filename) + # use an env variable to parse ir from file + if use_ir_loc == ext: + ir_full_name = fn_cache_manager.get_file(ir_filename) + next_module.create_location_snapshot(ir_full_name) + print(f"Creating new locations for {ir_full_name}") + module = next_module + # write-back metadata + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + # Compilation completed, disabling multithreading in context. + # This is needed to safely finalize threads pool inside context: if current process forks before + # python GC deletes context object, thread pool in child process will be invalid, which could + # lead to child crash or hang. + context.disable_multithreading() + # return handle to compiled kernel + return CompiledKernel(src, metadata_group, hash) + + +def make_backend(target): + actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)] + if len(actives) != 1: + raise RuntimeError( + f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). 
There should only be one.") + return actives[0](target) + + +class LazyDict: + + def __init__(self, data): + self.data = data + self.extras = [] + + def get(self) -> None: + for func, args in self.extras: + self.data = self.data | func(*args) + self.extras.clear() + return self.data + + def add(self, func, args): + self.extras.append((func, args)) + + +class AsmDict(dict): + + def __missing__(self, key): + + if key == "sass": + value = get_sass(self["cubin"]) + else: + raise KeyError("Unknown key: '%s'" % key) + + self[key] = value + return value + + +class CompiledKernel: + + # Hooks for external tools to monitor the execution of triton kernels + # TODO: move out of this namespace since it's a runtime thing + launch_enter_hook = None + launch_exit_hook = None + + def __init__(self, src, metadata_group, hash): + from collections import namedtuple + metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json"))) + metadata = json.loads(metadata_path.read_text()) + metadata['cluster_dims'] = tuple(metadata['cluster_dims']) + # JSON serialization dumps the target as a dict. Restore it to a GPUTarget. + target = metadata['target'] + metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size']) + KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys()))) + self.metadata = KernelMetadata(**metadata) + backend = make_backend(self.metadata.target) + self.packed_metadata = backend.pack_metadata(self.metadata) + self.src = src + self.hash = hash + self.name = self.metadata.name + # stores the text of each level of IR that was generated during compilation + asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")] + binary_ext = backend.binary_ext + self.asm = AsmDict({ + file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text() + for file in asm_files + }) + self.kernel = self.asm[binary_ext] + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.module = None + self.function = None + + def _init_handles(self): + if self.module is not None: + return + device = driver.active.get_current_device() + # create launcher + self.run = driver.active.launcher_cls(self.src, self.metadata) + # not enough shared memory to run the kernel + max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] + if self.metadata.shared > max_shared: + raise OutOfResources(self.metadata.shared, max_shared, "shared memory") + # TODO: n_regs, n_spills should be metadata generated when calling `ptxas` + self.module, self.function, self.n_regs, self.n_spills = driver.active.utils.load_binary( + self.name, self.kernel, self.metadata.shared, device) + + def __getattribute__(self, name): + if name == 'run': + self._init_handles() + return super().__getattribute__(name) + + def launch_metadata(self, grid, stream, *args): + if CompiledKernel.launch_enter_hook is None: + return None + ret = LazyDict({"name": self.name, "function": self.function, "stream": stream}) + if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None: + return ret + arg_dict = {} + arg_idx = 0 + for i, arg_name in enumerate(self.src.fn.arg_names): + if i in self.src.fn.constexprs: + arg_dict[arg_name] = self.src.constants[arg_name] + else: + arg_dict[arg_name] = args[arg_idx] + arg_idx += 1 + ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict)) + return ret + + def __getitem__(self, 
grid): + self._init_handles() + def runner(*args, stream=None): + if stream is None: + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + launch_metadata = self.launch_metadata(grid, stream, *args) + self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, *args) -def compiledKernel_getattribute_disable_init_handles(): - return True + return runner diff --git a/third_party/ascend/backend/spec/triton/compiler/errors.py b/third_party/ascend/backend/spec/triton/compiler/errors.py index 55d701be2..c72ca65c2 100644 --- a/third_party/ascend/backend/spec/triton/compiler/errors.py +++ b/third_party/ascend/backend/spec/triton/compiler/errors.py @@ -1,7 +1,54 @@ -import importlib.util -import sys +import ast from typing import Optional -from triton.compiler.errors import TritonError +from ..errors import TritonError + + +class CompilationError(TritonError): + """Base class for all errors raised during compilation""" + source_line_count_max_in_message = 12 + + def _format_message(self) -> str: + node = self.node + if self.src is None: + source_excerpt = " " + else: + if hasattr(node, 'lineno'): + source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] + if source_excerpt: + source_excerpt.append(' ' * node.col_offset + '^') + source_excerpt = '\n'.join(source_excerpt) + else: + source_excerpt = " " + else: + source_excerpt = self.src + + message = "at {}:{}:\n{}".format(node.lineno, node.col_offset, source_excerpt) if hasattr( + node, 'lineno') else source_excerpt + if self.error_message: + message += '\n' + self.error_message + return message + + def __init__(self, src: Optional[str], node: ast.AST, error_message: Optional[str] = None): + self.src = src + self.node = node + self.error_message = error_message + self.message = self._format_message() + + def __str__(self): + return self.message + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return type(self), (self.src, self.node, self.error_message) + + +class CompileTimeAssertionFailure(CompilationError): + """Specific exception for failed tests in `static_assert` invocations""" + pass + + +class UnsupportedLanguageConstruct(CompilationError): + pass class MLIRCompilationError(TritonError): diff --git a/third_party/ascend/backend/spec/triton/language/__init__.py b/third_party/ascend/backend/spec/triton/language/__init__.py new file mode 100644 index 000000000..a1cd9c50d --- /dev/null +++ b/third_party/ascend/backend/spec/triton/language/__init__.py @@ -0,0 +1,19 @@ +def language_modify_all(all_array): + try: + import acl + is_compile_on_910_95 = acl.get_soc_name().startswith("Ascend910_95") + except Exception as e: + is_compile_on_910_95 = False + + from .standard import topk + from .core import ( + make_tensor_descriptor, + load_tensor_descriptor, + store_tensor_descriptor, + gather, + ) + all_array.append("topk") + all_array.append("make_tensor_descriptor") + all_array.append("load_tensor_descriptor") + all_array.append("store_tensor_descriptor") + all_array.append("gather") diff --git a/third_party/ascend/backend/spec/triton/language/_utils.py b/third_party/ascend/backend/spec/triton/language/_utils.py index 0d740f8a4..65a015a9d 100644 --- a/third_party/ascend/backend/spec/triton/language/_utils.py +++ b/third_party/ascend/backend/spec/triton/language/_utils.py @@ -1,14 +1,78 @@ from __future__ import annotations -from 
typing import TYPE_CHECKING, Any, Union, Dict +from typing import List, TYPE_CHECKING, Any, Union, Dict + if TYPE_CHECKING: - from triton.language import core + from .language import core IterableType = Union[list[Any], tuple[Any, ...], core.tuple, core.tuple_type] ObjPath = tuple[int, ...] +TRITON_MAX_TENSOR_NUMEL = 1048576 + + +def is_power_of_two(x): + return (x & (x - 1)) == 0 + + +def validate_block_shape(shape: List[int]): + numel = 1 + for i, d in enumerate(shape): + if not isinstance(d, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]") + # FIXME:patched triton community + # if not is_power_of_two(d): + # raise ValueError(f"Shape element {i} must be a power of 2") + numel *= d + + if numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + return numel -def block_shape_disable_check_power_of_two(): - return True + +type_canonicalisation_dict = { + # we canonicalise all bools to be unsigned: + "bool": "u1", + "int1": "u1", + "uint1": "u1", + "i1": "u1", + # floating-point dtypes: + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", + "half": "fp16", + "float16": "fp16", + "bfloat16": "bf16", + "float": "fp32", + "float32": "fp32", + "double": "fp64", + "float64": "fp64", + # signed integers: + "int8": "i8", + "int16": "i16", + "int": "i32", + "int32": "i32", + "int64": "i64", + # unsigned integers: + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", + "void": "void", +} + +for v in list(type_canonicalisation_dict.values()): + type_canonicalisation_dict[v] = v + + +def canonicalize_dtype(dtype): + dtype_str = str(dtype).split(".")[-1] + return type_canonicalisation_dict[dtype_str] BITWIDTH_DICT: Dict[str, int] = { @@ -24,6 +88,9 @@ def block_shape_disable_check_power_of_two(): "void": 0, } +for k, v in type_canonicalisation_dict.items(): + BITWIDTH_DICT[k] = BITWIDTH_DICT[v] + def get_primitive_bitwidth(dtype: str) -> int: return BITWIDTH_DICT[dtype] diff --git a/third_party/ascend/backend/spec/triton/language/core.py b/third_party/ascend/backend/spec/triton/language/core.py index 02ccb342c..9ed34a899 100644 --- a/third_party/ascend/backend/spec/triton/language/core.py +++ b/third_party/ascend/backend/spec/triton/language/core.py @@ -1,306 +1,3358 @@ -from typing import List, Sequence, Union -from triton._C.libtriton import ir -import triton.language.semantic as semantic -from . 
import semantic as semantic_spec -from triton.language.core import ( - builtin, - _tensor_member_fn, - _unwrap_iterable, - _constexpr_to_value, - constexpr, - tensor, - range, - float32, - check_bit_width, - _unwrap_if_constexpr, - add, - sub, - mul, -) - -from .tensor_descriptor import tensor_descriptor, tensor_descriptor_base - -try: - import acl - is_compile_on_910_95 = acl.get_soc_name().startswith("Ascend910_95") -except Exception as e: - is_compile_on_910_95 = False - - -def enable_care_padding_load(): - return True - - -def ext_cast_set_overflow_modes(): - return ["trunc", "saturate"] - - -def ext_cast_check_overflow_mode(overflow_mode, overflow_modes, ret, _builder): - if overflow_mode is not None: - if overflow_mode in overflow_modes: - semantic_spec.ext_semantic_compile_hint(ret, "overflow_mode", overflow_mode, _builder) +from __future__ import annotations + +from warnings import warn +from contextlib import contextmanager +from enum import Enum +from functools import partial, wraps +import typing +from typing import Union, Callable, List, Sequence, TypeVar, Optional +import builtins +from ..runtime.jit import jit +import inspect +import os + +from .._C.libtriton import ir +from . import semantic +from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape, get_primitive_bitwidth + +T = TypeVar('T') + +TRITON_BUILTIN = "__triton_builtin__" + +PropagateNan = ir.PROPAGATE_NAN + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + print(kwargs) + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + + return wrapper + + +def _tensor_member_fn(fn: T) -> T: + """Decorator that adds this free function as a member fn on class tensor. + + When called as a member function on class tensor, the first argument to `fn` + is `self`, i.e. the tensor object. + + If there are multiple decorators on a function, you probably want this one + to be the highest one (i.e. furthest from the function's `def`), so it's + applied last. + + Unfortunately you still need to add a type stub to the body of class tensor + in order for pytype to know about it. + """ + assert callable(fn) + orig_sig = inspect.signature(fn) + # Does fn take args other than _builder, _generator, and the tensor itself? + has_args = len(orig_sig.parameters.keys() - {"_builder", "_generator"}) > 1 + + if not fn.__doc__: + fn.__doc__ = "" + fn.__doc__ += f""" + This function can also be called as a member function on :py:class:`tensor`, + as :code:`x.{fn.__name__}({"..." if has_args else ""})` instead of + :code:`{fn.__name__}(x{", ..." if has_args else ""})`. + """ + + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Match the signature of `fn`, but change the first arg to `self` so the + # docs are a little less weird. + new_params = list(orig_sig.parameters.values()) + new_params[0] = new_params[0].replace(name='self') + new_sig = orig_sig.replace(parameters=new_params) + wrapper.__signature__ = new_sig + wrapper.__doc__ = f"Forwards to :py:func:`{fn.__name__}` free function" + # If fn is a builtin, mark the wrapper as a builtin too. 
+ if is_builtin(fn): + setattr(wrapper, TRITON_BUILTIN, True) + + setattr(tensor, fn.__name__, wrapper) + return fn + + +def _unwrap_iterable(x): + """Returns x[0] if x has one element and x[0] is iterable.""" + if len(x) == 1: + # Determine whether x[0] is iterable. + # + # You might want to use collections.abc.Iterable instead of this + # try/except block. Unfortunately, this doesn't work with constexpr. + # + # The problem is that abc.Iterable checks for __iter__ on the *class*. + # But we want constexpr to expose an __iter__ method if and only if the + # wrapped *object* (i.e. self.value) is iterable. Therefore there's no + # right answer for whether the class constexpr defines __iter__, and + # abc.Iterable doesn't work (at least not without some metaclass magic). + try: + iter(x[0]) + return x[0] + except TypeError: + pass + + return x + + +def is_builtin(fn) -> bool: + """Is this a registered triton builtin function?""" + return getattr(fn, TRITON_BUILTIN, False) + + +@builtin +def to_tensor(x, _builder=None): + return semantic.to_tensor(x, _builder) + + +# ----------------------- +# constexpr +# ----------------------- + + +class const: + """ + This class is used as a type annotation to mark pointers to constant data. + The `store` function cannot be called with a pointer to const. Constness + is part of the pointer type and the usual Triton type consistency rules + apply. For example you cannot have a function that returns constant pointer + in one return statement and non-constant pointer in another. + """ + pass + + +class constexpr: + """ + This class is used to store a value that is known at compile-time. + """ + + def __init__(self, value): + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def __index__(self): + return self.value + + # In interpreter mode, constant values are not wrapped in constexpr, + # and therefore do not have a .value attribute. + # As a result, from here and below, we need to call the _constexpr_to_value + # function to obtain either constexpr.value or the value itself. 
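+    # A minimal illustration (assumed values, not part of the original change): constexpr(3) + 4
+    # folds at compile time into constexpr(7), because __add__ below unwraps `other` via
+    # _constexpr_to_value and re-wraps the plain Python result in a new constexpr.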
+ def __add__(self, other): + return constexpr(self.value + _constexpr_to_value(other)) + + def __radd__(self, other): + return constexpr(_constexpr_to_value(other) + self.value) + + def __sub__(self, other): + return constexpr(self.value - _constexpr_to_value(other)) + + def __rsub__(self, other): + return constexpr(_constexpr_to_value(other) - self.value) + + def __mul__(self, other): + return constexpr(self.value * _constexpr_to_value(other)) + + def __mod__(self, other): + return constexpr(self.value % _constexpr_to_value(other)) + + def __rmul__(self, other): + return constexpr(_constexpr_to_value(other) * self.value) + + def __truediv__(self, other): + return constexpr(self.value / _constexpr_to_value(other)) + + def __rtruediv__(self, other): + return constexpr(_constexpr_to_value(other) / self.value) + + def __floordiv__(self, other): + return constexpr(self.value // _constexpr_to_value(other)) + + def __rfloordiv__(self, other): + return constexpr(_constexpr_to_value(other) // self.value) + + def __gt__(self, other): + return constexpr(self.value > _constexpr_to_value(other)) + + def __rgt__(self, other): + return constexpr(_constexpr_to_value(other) > self.value) + + def __ge__(self, other): + return constexpr(self.value >= _constexpr_to_value(other)) + + def __rge__(self, other): + return constexpr(_constexpr_to_value(other) >= self.value) + + def __lt__(self, other): + return constexpr(self.value < _constexpr_to_value(other)) + + def __rlt__(self, other): + return constexpr(_constexpr_to_value(other) < self.value) + + def __le__(self, other): + return constexpr(self.value <= _constexpr_to_value(other)) + + def __rle__(self, other): + return constexpr(_constexpr_to_value(other) <= self.value) + + def __eq__(self, other): + return constexpr(self.value == _constexpr_to_value(other)) + + def __ne__(self, other): + return constexpr(self.value != _constexpr_to_value(other)) + + def __bool__(self): + return bool(self.value) + + def __neg__(self): + return constexpr(-self.value) + + def __and__(self, other): + return constexpr(self.value & _constexpr_to_value(other)) + + def logical_and(self, other): + return constexpr(self.value and _constexpr_to_value(other)) + + def __or__(self, other): + return constexpr(self.value | _constexpr_to_value(other)) + + def __xor__(self, other): + return constexpr(self.value ^ _constexpr_to_value(other)) + + def logical_or(self, other): + return constexpr(self.value or _constexpr_to_value(other)) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) + + def __pow__(self, other): + return constexpr(self.value**_constexpr_to_value(other)) + + def __rpow__(self, other): + return constexpr(_constexpr_to_value(other)**self.value) + + def __rshift__(self, other): + return constexpr(self.value >> _constexpr_to_value(other)) + + def __lshift__(self, other): + return constexpr(self.value << _constexpr_to_value(other)) + + def __not__(self): + return constexpr(not self.value) + + def __iter__(self): + return iter(self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + + +CONSTEXPR_0 = constexpr(0) + + +def _unwrap_if_constexpr(o): + return o.value if isinstance(o, constexpr) else o + + +def check_bit_width(value, shift_value): + if isinstance(value, tensor) and isinstance(shift_value, constexpr): + bitwidth = value.type.scalar.primitive_bitwidth + if shift_value.value >= bitwidth: + warn( + f"Value {shift_value.value} exceeds the maximum bitwidth ({bitwidth}) for type 
'{value.dtype}'. This may result in undefined behavior." + ) + + +# ----------------------- +# dtype +# ----------------------- + + +class dtype: + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + class KIND(Enum): + BOOLEAN = 0 + INTEGRAL = 1 + FLOATING = 2 + + def __init__(self, name): + name = _unwrap_if_constexpr(name) + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8e4b15': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e4nv': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.primitive_bitwidth = 8 + self.exponent_bias = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + self.exponent_bias = 16 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.primitive_bitwidth = 16 + self.exponent_bias = 15 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.primitive_bitwidth = 16 + self.exponent_bias = 127 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.primitive_bitwidth = 32 + self.exponent_bias = 127 + elif name == 'fp64': + self.fp_mantissa_width = 52 + self.primitive_bitwidth = 64 + self.exponent_bias = 1023 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + elif name == 'void': + self.primitive_bitwidth = 0 + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' + + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + + def is_fp8e4b15(self): + return self.name == 'fp8e4b15' + + def is_fp8e5(self): + return self.name == 'fp8e5' + + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int_unsigned(self): + return self.name in 
dtype.UINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + def kind(self): + # Return int value following the type ordering bool < integer < fp + if self.is_bool(): + return dtype.KIND.BOOLEAN + elif self.is_int(): + return dtype.KIND.INTEGRAL + else: + assert self.is_floating() + return dtype.KIND.FLOATING + + def get_int_max_value(self): + if self.is_int_signed(): + return 2**(self.int_bitwidth - 1) - 1 + if self.is_int_unsigned(): + return 2**self.int_bitwidth - 1 + assert False + + def get_int_min_value(self): + if self.is_int_signed(): + return -2**(self.int_bitwidth - 1) + if self.is_int_unsigned(): + return 0 + assert False + + @staticmethod + def is_dtype(type_str): + return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES + + @staticmethod + def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + @staticmethod + def is_const(): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __ne__(self, other: dtype): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.name, )) + + @property + def scalar(self): + return self + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name.startswith("fp8"): + if self.name not in builder.options.supported_fp8_dtypes: + raise ValueError(f'type {self} not supported in this architecture. ' + f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}') + if self.name in builder.options.deprecated_fp8_dtypes: + warn(f"{self.name} is deprecated in this architecture and will be removed in a future triton release") + + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + def codegen_name(self): + if self.name.startswith("fp"): + return "float" + self.name[2:] + elif self.name.startswith("bf"): + return "bfloat" + self.name[2:] + else: + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' + + +# Some functions have a param named `dtype`, which shadows the `dtype` class. +# We can't change the param name because it is part of function's public API. 
+# Declare an alias so those functions can still reference the dtype class. +_DtypeClass = dtype + + +class pointer_type(dtype): + + def __init__(self, element_ty: dtype, address_space: int = 1, const: bool = False): + element_ty = _unwrap_if_constexpr(element_ty) + if not isinstance(element_ty, dtype): + raise TypeError(f'element_ty has type `{type(element_ty).__name__}`; expected `dtype`.') + self.element_ty = element_ty + self.address_space = address_space + self.const = const + self.name = f'pointer<{element_ty}>' if not const else f'const_pointer<{element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.pointer_type: + return builder.get_ptr_ty(self.element_ty.to_ir(builder), self.address_space) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_ptr(self): + return True + + def is_const(self): + return self.const + + def __eq__(self, other: pointer_type) -> bool: + if not isinstance(other, pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const + + def __ne__(self, other: pointer_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self + + +class nv_tma_desc_type(pointer_type): + + def __init__(self, const=True, address_space=0): + super().__init__(uint8, const=const, address_space=address_space) + self.name = 'nv_tma_desc_type' + + +class block_type(dtype): + + def __init__(self, element_ty: dtype, shape: List): + self.element_ty = element_ty + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + + # shape can be empty ([]) when an input is a 0D tensor. + self.shape = _unwrap_shape(shape) + if not self.shape: + raise TypeError('0d block_type is forbidden') + + self.numel = validate_block_shape(self.shape) + self.name = f'<{self.shape}, {self.element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> List[int]: + return self.shape + + def __eq__(self, other: block_type) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + def __ne__(self, other: block_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self.element_ty + + +class function_type(dtype): + + def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: + self.ret_types = ret_types + self.param_types = param_types + + def __str__(self): + return f'fn ({self.param_types}) -> {self.ret_types}' + + def to_ir(self, builder: ir.builder): + ir_param_types = [ty.to_ir(builder) for ty in self.param_types] + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(ir_param_types, ret_types) + + +# scalar types +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8e5 = dtype('fp8e5') +float8e5b16 = dtype('fp8e5b16') +float8e4nv = dtype('fp8e4nv') +float8e4b8 = dtype('fp8e4b8') +float8e4b15 = dtype('fp8e4b15') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = 
dtype('fp32') +float64 = dtype('fp64') +# pointer types +pi32_t = pointer_type(int32) + + +def get_int_dtype(bitwidth: int, signed: bool) -> dtype: + if bitwidth == 1: + return int1 + elif bitwidth == 8 and signed: + return int8 + elif bitwidth == 8 and not signed: + return uint8 + elif bitwidth == 16 and signed: + return int16 + elif bitwidth == 16 and not signed: + return uint16 + elif bitwidth == 32 and signed: + return int32 + elif bitwidth == 32 and not signed: + return uint32 + elif bitwidth == 64 and signed: + return int64 + elif bitwidth == 64 and not signed: + return uint64 + else: + raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}') + + +class _value: + """Base class of values that exist in the triton IR (i.e. not constexprs). + """ + + def __init__(self, handle): + self.handle = handle + + +# ----------------------- +# tensor +# ----------------------- + + +class tensor(_value): + """Represents an N-dimensional array of values or pointers. + + :code:`tensor` is the fundamental data structure in Triton programs. Most + functions in :py:mod:`triton.language` operate on and return tensors. + + Most of the named member functions here are duplicates of the free functions + in :code:`triton.language`. For example, :code:`triton.language.sqrt(x)` is + equivalent to :code:`x.sqrt()`. + + :code:`tensor` also defines most of the magic/dunder methods, so you can + write :code:`x+y`, :code:`x << 2`, etc. + + .. rubric:: Constructors + .. + For some reason Sphinx includes __init__ before printing the full table + of methods. Not what I want, but I can't figure out how to fix it. Give + it its own section so it looks intentional. :) + """ + + def __init__(self, handle, type: dtype): + """Not called by user code.""" + # IR handle + super().__init__(handle) + # Block shape + self.shape = type.shape if type.is_block() else () + self.numel = 1 + for s in self.shape: + self.numel *= s + self.numel = constexpr(self.numel) + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar + self.shape = [constexpr(s) for s in self.shape] + + def __str__(self) -> str: + # ex. 
"float32[16, 32]" + return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' + + @builtin + def __add__(self, other, _builder=None): + return add(self, other, sanitize_overflow=False, _builder=_builder) + + @builtin + def __radd__(self, other, _builder=None): + return add(other, self, sanitize_overflow=False, _builder=_builder) + + @builtin + def __sub__(self, other, _builder=None): + return sub(self, other, sanitize_overflow=False, _builder=_builder) + + @builtin + def __rsub__(self, other, _builder=None): + return sub(other, self, sanitize_overflow=False, _builder=_builder) + + @builtin + def __mul__(self, other, _builder=None): + return mul(self, other, sanitize_overflow=False, _builder=_builder) + + @builtin + def __rmul__(self, other, _builder=None): + return mul(other, self, sanitize_overflow=False, _builder=_builder) + + @builtin + def __truediv__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.truediv(self, other, _builder) + + @builtin + def __rtruediv__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.truediv(other, self, _builder) + + @builtin + def __floordiv__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.floordiv(self, other, _builder) + + @builtin + def __rfloordiv__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.floordiv(other, self, _builder) + + @builtin + def __mod__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.mod(self, other, _builder) + + @builtin + def __rmod__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.mod(other, self, _builder) + + # unary operators + @builtin + def __neg__(self, _builder=None): + return semantic.minus(self, _builder) + + @builtin + def __invert__(self, _builder=None): + return semantic.invert(self, _builder) + + # bitwise operators + + @builtin + def __and__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.and_(self, other, _builder) + + @builtin + def __rand__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.and_(other, self, _builder) + + @builtin + def __or__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.or_(self, other, _builder) + + @builtin + def __ror__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.or_(other, self, _builder) + + @builtin + def __xor__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.xor_(self, other, _builder) + + @builtin + def __rxor__(self, other, _builder=None): + other = _unwrap_if_constexpr(other) + return semantic.xor_(other, self, _builder) + + @builtin + def __lshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _unwrap_if_constexpr(other) + return semantic.shl(self, other, _builder) + + @builtin + def __rlshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _unwrap_if_constexpr(other) + return semantic.shl(other, self, _builder) + + @builtin + def __rshift__(self, other, _builder=None): + check_bit_width(self, other) + other = _unwrap_if_constexpr(other) + if self.dtype.is_int_signed(): + return semantic.ashr(self, other, _builder) + else: + return semantic.lshr(self, other, _builder) + + @builtin + def __rrshift__(self, other, _builder=None): + check_bit_width(other, self) + other = _unwrap_if_constexpr(other) + if 
self.dtype.is_int_signed(): + return semantic.ashr(other, self, _builder) + else: + return semantic.lshr(other, self, _builder) + + # > + @builtin + def __gt__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.greater_than(self, other, _builder) + + @builtin + def __rgt__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.greater_than(other, self, _builder) + + # >= + @builtin + def __ge__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.greater_equal(self, other, _builder) + + @builtin + def __rge__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.greater_equal(other, self, _builder) + + # < + @builtin + def __lt__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.less_than(self, other, _builder) + + @builtin + def __rlt__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.less_than(other, self, _builder) + + # <= + @builtin + def __le__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.less_equal(self, other, _builder) + + @builtin + def __rle__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.less_equal(other, self, _builder) + + # == + @builtin + def __eq__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.equal(self, other, _builder) + + @builtin + def __req__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.equal(other, self, _builder) + + @builtin + def __ne__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.not_equal(self, other, _builder) + + @builtin + def __rne__(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.not_equal(other, self, _builder) + + @builtin + def logical_and(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.logical_and(self, other, _builder) + + @builtin + def logical_or(self, other, _builder=None): + other = semantic.to_tensor(other, _builder) + return semantic.logical_or(self, other, _builder) + + # note: __not__ isn't actually a magic method in python + # but it's ok because our ASTVisitor handles it + @builtin + def __not__(self, _builder=None): + return semantic.not_(self, _builder) + + @builtin + def __getitem__(self, slices, _builder=None): + if isinstance(slices, (slice, constexpr)) or slices is None: + slices = [slices] + ret = self + for dim, sl in enumerate(slices): + if sl is None or isinstance(sl, constexpr) and sl.value is None: + ret = semantic.expand_dims(ret, dim, _builder) + elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: + pass + else: + raise ValueError(f"unsupported tensor index: {sl}") + return ret + + @property + def T(self): + """Transposes a 2D tensor.""" + assert False, "Transposition must be created by the AST Visitor" + + @builtin + def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, + overflow_mode: Optional[str] = None, _builder=None): + """ + Alias for :py:func:`tensor.cast`. + """ + # Triton doesn't like core functions calling other core functions, so we + # just copy-paste the implementation of cast here. It's not too bad. 
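+        # [Editor's note -- illustrative only, not part of this patch] Inside a
+        # @triton.jit kernel the following spellings are equivalent, since `to`
+        # is an alias for `cast`:
+        #     y = x.to(tl.float16)                 # numeric downcast
+        #     y = tl.cast(x, tl.float16)           # same result
+        # and, for a float32 tensor x, a bit reinterpretation of equal width:
+        #     u = x.to(tl.uint32, bitcast=True)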
+ dtype = _unwrap_if_constexpr(dtype) + bitcast = _unwrap_if_constexpr(bitcast) + if bitcast: + return semantic.bitcast(self, dtype, _builder) + return semantic.cast(self, dtype, _builder, fp_downcast_rounding, overflow_mode) + + # Type stubs for functions added by the _tensor_member_fn decorator. + # (Unfortunately these can't be created automatically.) + # + # We couldn't write these definitions out even if we wanted to, because some + # of these functions are defined in standard.py. + def broadcast_to(self, *shape) -> tensor: + ... + + def trans(self, *dims) -> tensor: + ... + + def permute(self, *dims) -> tensor: + ... + + def split(self) -> tuple[tensor, tensor]: + ... + + def view(self, *shape) -> tensor: + ... + + def reshape(self, *shape) -> tensor: + ... + + def expand_dims(self, axis) -> tensor: + ... + + def cast(self, dtype, fp_downcast_rounding=None, bitcast=False) -> tensor: + ... + + def store(self, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="") -> tensor: + ... + + def advance(self, offsets) -> tensor: + ... + + def atomic_cas(self, cmp, val, sem=None, scope=None) -> tensor: + ... + + def atomic_xchg(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_add(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_max(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_min(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_and(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_or(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def atomic_xor(self, val, mask=None, sem=None, scope=None) -> tensor: + ... + + def exp(self) -> tensor: + ... + + def log(self) -> tensor: + ... + + def cos(self) -> tensor: + ... + + def sin(self) -> tensor: + ... + + def sqrt(self) -> tensor: + ... + + def rsqrt(self) -> tensor: + ... + + def abs(self) -> tensor: + ... + + def reduce(self, axis, combine_fn, keep_dims=False) -> tensor: + ... + + def associative_scan(self, axis, combine_fn, reverse=False) -> tensor: + ... + + def gather(self, indices, axis) -> tensor: + ... + + def histogram(self, num_bins) -> tensor: + ... + + def cdiv(self, div) -> tensor: + ... + + def sigmoid(self) -> tensor: + ... + + def softmax(self, ieee_rounding=False) -> tensor: + ... + + def ravel(self) -> tensor: + ... + + def max(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmax(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def min(self, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False) -> tensor: + ... + + def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor: + ... + + def sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def xor_sum(self, axis=None, keep_dims=False) -> tensor: + ... + + def cumsum(self, axis=0, reverse=False) -> tensor: + ... + + def cumprod(self, axis=0, reverse=False) -> tensor: + ... + + def sort(self, dim: constexpr = None, descending: constexpr = CONSTEXPR_0) -> tensor: + ... + + def flip(self, dim=None) -> tensor: + ... 
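# [Editor's note] Illustrative sketch, not part of this patch: a small user
# kernel exercising the tensor member functions and operator overloads defined
# above (a comparison producing an int1 mask, arithmetic, and a `.to()` cast).
# The kernel name, pointer names, and the launch shown below are hypothetical.
import triton
import triton.language as tl


@triton.jit
def relu_half_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)  # per-program block of indices
    mask = offs < n                           # tensor.__lt__ -> int1 mask
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    y = tl.where(x > 0, x, 0.0)               # tensor.__gt__ feeding tl.where
    tl.store(out_ptr + offs, y.to(tl.float16), mask=mask)  # tensor.to()


# Hypothetical launch, assuming x is float32 and out is float16:
#   grid = (triton.cdiv(n, 1024),)
#   relu_half_kernel[grid](x, out, n, BLOCK=1024)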
+ + +def get_bool_env_var(var_name): + v = os.getenv(var_name, "0") + return v == "1" or v == "true" or v == "on" + + +# ----------------------- +# SPMD Programming Model +# ----------------------- +def _constexpr_to_value(v): + if isinstance(v, constexpr): + return v.value + return v + + +@builtin +def program_id(axis, _builder=None): + """ + Returns the id of the current program instance along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + # if axis == -1: + # pid0 = program_id(0, _builder) + # pid1 = program_id(1, _builder) + # pid2 = program_id(2, _builder) + # npg0 = num_programs(0, _builder) + # npg1 = num_programs(1, _builder) + # return pid0 + pid1*npg0 + pid2*npg0*npg1 + axis = _constexpr_to_value(axis) + return semantic.program_id(axis, _builder) + + +@builtin +def num_programs(axis, _builder=None): + """ + Returns the number of program instances launched along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Must be 0, 1 or 2. + :type axis: int + """ + axis = _constexpr_to_value(axis) + return semantic.num_programs(axis, _builder) + + +# ----------------------- +# Block Initialization +# ----------------------- + + +@builtin +def arange(start, end, _builder=None): + start = _constexpr_to_value(start) + end = _constexpr_to_value(end) + return semantic.arange(start, end, _builder) + + +arange.__doc__ = f""" + Returns contiguous values within the half-open interval :code:`[start, + end)`. :code:`end - start` must be less than or equal to + :code:`TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}` + + :param start: Start of the interval. Must be a power of two. + :type start: int32 + :param end: End of the interval. Must be a power of two greater than + :code:`start`. + :type end: int32 +""" + + +def _unwrap_shape(shape): + shape = _constexpr_to_value(shape) + return [_constexpr_to_value(s) for s in shape] + + +def _shape_check_impl(shape): + shape = _unwrap_shape(shape) + validate_block_shape(shape) + return shape + + +@builtin +def full(shape, value, dtype, _builder=None): + """ + Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param value: A scalar value to fill the array with + :type value: scalar + :param dtype: Data type of the new array, e.g., :code:`tl.float16` + :type dtype: tl.dtype + """ + shape = _shape_check_impl(shape) + value = _constexpr_to_value(value) + dtype = _constexpr_to_value(dtype) + return semantic.full(shape, value, dtype, _builder) + + +# ----------------------- +# Shape Manipulation +# ----------------------- + + +@builtin +def broadcast(input, other, _builder=None): + """ + Tries to broadcast the two given blocks to a common compatible shape. + + :param input: The first input tensor. + :type input: Block + :param other: The second input tensor. + :type other: Block + """ + return semantic.broadcast_impl_value(input, other, _builder) + + +@_tensor_member_fn +@builtin +def broadcast_to(input, *shape, _builder=None): + """ + Tries to broadcast the given tensor to a new :code:`shape`. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. 
+ :type shape: + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + broadcast_to(x, (32, 32)) + broadcast_to(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.broadcast_impl_shape(input, shape, _builder) + + +@_tensor_member_fn +@builtin +def trans(input: tensor, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + If the parameter :code:`dims` is not specified, the function defaults to a (1,0) permutation, + effectively transposing a 2D tensor. + + :param input: The input tensor. + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + trans(x, (2, 1, 0)) + trans(x, 2, 1, 0) + + :py:func:`permute` is equivalent to this function, except it doesn't + have the special case when no permutation is specified. + """ + if not dims: + dims = (1, 0) + dims = _unwrap_iterable(dims) + return semantic.permute(input, dims, _builder) + + +@_tensor_member_fn +@builtin +def permute(input, *dims, _builder=None): + """ + Permutes the dimensions of a tensor. + + :param input: The input tensor. + :type input: Block + :param dims: The desired ordering of dimensions. For example, + :code:`(2, 1, 0)` reverses the order dims in a a 3D tensor. + + :code:`dims` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + permute(x, (2, 1, 0)) + permute(x, 2, 1, 0) + + :py:func:`trans` is equivalent to this function, except when + :code:`dims` is empty, it tries to do a (1,0) permutation. + """ + dims = _unwrap_iterable(dims) + return semantic.permute(input, dims, _builder) + + +@builtin +def cat(input, other, can_reorder=False, _builder=None): + """ + Concatenate the given blocks + + :param input: The first input tensor. + :type input: Tensor + :param other: The second input tensor. + :type other: Tensor + :param reorder: Compiler hint. If true, the compiler is + allowed to reorder elements while concatenating inputs. Only use if the + order does not matter (e.g., result is only used in reduction ops). + Current implementation of `cat` supports only can_reorder=True. + """ + return semantic.cat(input, other, can_reorder, _builder) + + +@builtin +def join(a, b, _builder=None): + """ + Join the given tensors in a new, minor dimension. + + For example, given two tensors of shape (4,8), produces a new tensor of + shape (4,8,2). Given two scalars, returns a tensor of shape (2). + + The two inputs are broadcasted to be the same shape. + + If you want to join more than two elements, you can use multiple calls to + this function. This reflects the constraint in Triton that tensors must + have power-of-two sizes. + + join is the inverse of split. + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + return semantic.join(a, b, _builder) + + +@jit +def _take_first(a, b): + return a + + +@_tensor_member_fn +@builtin +def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]: + """ + Split a tensor in two along its last dim, which must have size 2. + + For example, given a tensor of shape (4,8,2), produces two tensors of shape + (4,8). Given a tensor of shape (2), returns two scalars. + + If you want to split into more than two pieces, you can use multiple calls + to this function (probably plus calling reshape). 
This reflects the + constraint in Triton that tensors must have power-of-two sizes. + + split is the inverse of join. + + :param a: The tensor to split. + :type a: Tensor + """ + # If len(a.shape) == 1, i.e. a.shape == [2], we should return two scalars. + # But semantic.split can only handle returning tensors. Work around this by + # expanding the input to shape [1,2] and then reducing the result. + was_rank_1 = len(a.shape) == 1 + if was_rank_1: + a = semantic.expand_dims(a, 0, _builder) + + out_lhs, out_rhs = semantic.split(a, _builder) + + if was_rank_1: + # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar. + out_lhs = typing.cast(tensor, reduce(out_lhs, None, _take_first, _builder=_builder, _generator=_generator)) + out_rhs = typing.cast(tensor, reduce(out_rhs, None, _take_first, _builder=_builder, _generator=_generator)) + + return out_lhs, out_rhs + + +@_tensor_member_fn +@builtin +def view(input, *shape, _builder=None): + """ + Returns a tensor with the same elements as `input` but a different shape. + The order of the elements may not be preserved. + + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + view(x, (32, 32)) + view(x, 32, 32) + """ + warn("view is deprecated, please use reshape with can_reorder being true.") + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder=False, builder=_builder) + + +@_tensor_member_fn +@builtin +def reshape(input, *shape, can_reorder=False, _builder=None): + """ + Returns a tensor with the same number of elements as input but with the + provided shape. + + :param input: The input tensor. + :type input: Block + :param shape: The new shape. + + :code:`shape` can be passed as a tuple or as individual parameters: :: + + # These are equivalent + reshape(x, (32, 32)) + reshape(x, 32, 32) + """ + shape = _shape_check_impl(_unwrap_iterable(shape)) + return semantic.reshape(input, shape, can_reorder, _builder) + + +def _wrap_axis(axis, ndim): + if not (-ndim <= axis < ndim): + raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}") + + return axis if axis >= 0 else axis + ndim + + +@_tensor_member_fn +@builtin +def expand_dims(input, axis, _builder=None): + """ + Expand the shape of a tensor, by inserting new length-1 dimensions. + + Axis indices are with respect to the resulting tensor, so + ``result.shape[axis]`` will be 1 for each axis. + + :param input: The input tensor. + :type input: tl.tensor + :param axis: The indices to add new axes + :type axis: int | Sequence[int] + + """ + input = semantic.to_tensor(input, _builder) + axis = _constexpr_to_value(axis) + axes = list(axis) if isinstance(axis, Sequence) else [axis] + new_ndim = len(input.shape) + len(axes) + axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] + + if len(set(axes)) != len(axes): + raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}") + + ret = input + for a in sorted(axes): + ret = semantic.expand_dims(ret, a, _builder) + return ret + + +@_tensor_member_fn +@builtin +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, _builder=None): + """ + Casts a tensor to the given :code:`dtype`. + + :param dtype: The target data type. + :type dtype: tl.dtype + :param fp_downcast_rounding: The rounding mode for downcasting + floating-point values. 
This parameter is only used when self is a + floating-point tensor and dtype is a floating-point type with a + smaller bitwidth. Supported values are :code:`"rtne"` (round to + nearest, ties to even) and :code:`"rtz"` (round towards zero). + :type fp_downcast_rounding: str, optional + :param bitcast: If true, the tensor is bitcasted to the given + :code:`dtype`, instead of being numerically casted. + :type bitcast: bool, optional + """ + input = semantic.to_tensor(input, _builder) + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(input, dtype, _builder) + return semantic.cast(input, dtype, _builder, fp_downcast_rounding) + + +# ----------------------- +# Linear Algebra +# ----------------------- + + +@builtin +def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, + _builder=None): + """ + Returns the matrix product of two blocks. + + The two blocks must both be two-dimensional or three-dimensional and have compatible inner dimensions. + For three-dimensional blocks, `tl.dot` performs the batched matrix product, + where the first dimension of each block represents the batch dimension. + + :param input: The first tensor to be multiplied. + :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. + :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code: `float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`} + :param acc: The accumulator tensor. If not None, the result is added to this tensor. + :type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`} + :param input_precision: How to exercise the Tensor Cores for f32 x f32. If + the device does not have Tensor Cores or the inputs are not of dtype f32, + this option is ignored. For devices that do have tensor cores, the + default precision is tf32. + :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Avaliable options for amd: :code:`"ieee"`. + :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32". + Only one of :code:`input_precision` and :code:`allow_tf32` can be + specified (i.e. at least one must be :code:`None`). + """ + assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified" + assert not allow_tf32, "allow_tf32 is deprecated, please use input_precision='hf32' on Ascend instead." + if input_precision is None: + supports_tf32 = _builder and "tf32" in _builder.options.allowed_dot_input_precisions + default_precision = "tf32" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee" + input_precision = os.getenv("TRITON_F32_DEFAULT", default_precision) + else: + assert input_precision not in [ + "tf32", "tf32x3" + ], "input_precision == tf32 or tf32x3 is invalid, please use input_precision='hf32' on Ascend instead." 
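+    # [Editor's note -- descriptive comment, not part of this patch] The checks
+    # above are Ascend-specific in this fork: `allow_tf32` and the NVIDIA
+    # tf32/tf32x3 precision modes are rejected, and callers are directed to
+    # input_precision='hf32' instead. Hypothetical call with 2D fp32 blocks:
+    #     acc = tl.dot(a, b, input_precision="hf32")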
+ input_precision = _constexpr_to_value(input_precision) + out_dtype = _constexpr_to_value(out_dtype) + max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) + return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder) + + +@builtin +def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, out_dtype=float32, lhs_k_pack=True, + rhs_k_pack=True, _builder=None): + """ + Returns the matrix product of two blocks in microscaling format. + lhs and rhs use microscaling formats described here: + https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + :param lhs: The first tensor to be multiplied. + :type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :param lhs_scale: Scale factor for lhs tensor. + :type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). + :param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :param rhs: The second tensor to be multiplied. + :type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format. + :param rhs_scale: Scale factor for rhs tensor. + :type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). + :param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. + :param acc: The accumulator tensor. If not None, the result is added to this tensor. + """ + out_dtype = _constexpr_to_value(out_dtype) + assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment" + return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, out_dtype, lhs_k_pack, + rhs_k_pack, _builder) + + +# ----------------------- +# Non-Atomic Memory Operations +# ----------------------- + + +@builtin +def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", cache_modifier="", eviction_policy="", + volatile=False, care_padding=True, _builder=None): + """ + Return a tensor of data whose values are loaded from memory at location defined by `pointer`: + + (1) If `pointer` is a single element pointer, a scalar is be loaded. In + this case: + + - `mask` and `other` must also be scalars, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional tensor is loaded. In this case: + + - `mask` and `other` are implicitly broadcast to `pointer.shape`, + - `other` is implicitly typecast to `pointer.dtype.element_ty`, and + - `boundary_check` and `padding_option` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a + tensor is loaded. In this case: + + - `mask` and `other` must be `None`, and + - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access. 
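+    (Editor's sketch, not part of the original patch.) A typical case-(2)
+    masked load, assuming a 1D grid where :code:`offs = pid * BLOCK + tl.arange(0, BLOCK)`
+    and :code:`n` is the number of valid elements:
+
+    .. highlight:: python
+    .. code-block:: python
+
+        x = tl.load(x_ptr + offs, mask=offs < n, other=0.0)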
+ + :param pointer: Pointer to the data to be loaded + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]` + (must be `None` with block pointers) + :type mask: Block of `triton.int1`, optional + :param other: if `mask[idx]` is false, return `other[idx]` + :type other: Block, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param padding_option: should be one of {"", "zero", "nan"}, the padding value to use while out of bounds. "" means an undefined value. + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional, should be one of {"", "ca", "cg"}, where "ca" stands for + cache at all levels and "cg" stands for cache at global level (cache in L2 and below, not L1), see + `cache operator `_ for more details. + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + :param volatile: changes volatile option in NVIDIA PTX + :type volatile: bool, optional + :param care_padding: represents whether user cares about padding value or not, default is True, works as below: + 1. if 'other' is not None, 'care_padding' takes no effect. + 2. if 'other' is None and 'care_padding' = True, loaded tensor will fill zeroes on masked places. + 3. if 'other' is None and 'care_padding' = False, masked places on loaded tensor will be random values, and tl.load may have a better performence. + :type care_padding: bool, optional + """ + # `mask` and `other` can be constexpr + mask = _constexpr_to_value(mask) + other = _constexpr_to_value(other) + if mask is not None: + mask = semantic.to_tensor(mask, _builder) + if other is not None: + other = semantic.to_tensor(other, _builder) + padding_option = _constexpr_to_value(padding_option) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + volatile = _constexpr_to_value(volatile) + care_padding = _constexpr_to_value(care_padding) + return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy, + volatile, care_padding, _builder) + + +@builtin +def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None): + """ + Experimental feature to access TMA descriptors loads. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This loads a tensor of data based on the descriptor and offsets. + """ + type = block_type(_constexpr_to_value(dtype), shape) + return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder) + + +@builtin +def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None): + """ + Experimental feature to access TMA descriptors stores. This is an escape hatch to easily exercise TTGIR operations. + This will be removed in the future and shouldn't be used in production code. + + This stores a tensor of data based on the descriptor and offsets. + """ + return semantic.descriptor_store(desc_pointer, value, offsets, _builder) + + +@_tensor_member_fn +@builtin +def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None): + """ + Store a tensor of data into memory locations defined by `pointer`. + + (1) If `pointer` is a single element pointer, a scalar is stored. 
In + this case: + + - `mask` must also be scalar, and + - `boundary_check` and `padding_option` must be empty. + + (2) If `pointer` is an N-dimensional tensor of pointers, an + N-dimensional block is stored. In this case: + + - `mask` is implicitly broadcast to `pointer.shape`, and + - `boundary_check` must be empty. + + (3) If `pointer` is a block pointer defined by `make_block_ptr`, a block + of data is stored. In this case: + + - `mask` must be None, and + - `boundary_check` can be specified to control the behavior of out-of-bound access. + + `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`. + + :param pointer: The memory location where the elements of `value` are stored + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param value: The tensor of elements to be stored + :type value: Block + :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]` + :type mask: Block of triton.int1, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional, should be one of {"", ".wb", ".cg", ".cs", ".wt"}, where ".wb" stands for + cache write-back all coherent levels, ".cg" stands for cache global, ".cs" stands for cache streaming, ".wt" + stands for cache write-through, see `cache operator `_ for more details. + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional, should be one of {"", "evict_first", "evict_last"} + """ + # `value` can be constexpr + value = semantic.to_tensor(value, _builder) + mask = _constexpr_to_value(mask) + if mask is not None: + mask = semantic.to_tensor(mask, _builder) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder) + + +@builtin +def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None): + """ + Returns a pointer to a block in a parent tensor + + :param base: The base pointer to the parent tensor + :param shape: The shape of the parent tensor + :param strides: The strides of the parent tensor + :param offsets: The offsets to the block + :param block_shape: The shape of the block + :param order: The order of the original data format + """ + return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder) + + +@_tensor_member_fn +@builtin +def advance(base, offsets, _builder=None): + """ + Advance a block pointer + + :param base: the block pointer to advance + :param offsets: the offsets to advance, a tuple by dimension + """ + return semantic.advance(base, offsets, _builder) + + +# ----------------------- +# Atomic Memory Operations +# ----------------------- + + +def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = f""" + Performs an atomic {name} at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. 
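+    (Editor's sketch, not part of the original patch.) For instance, with the
+    add variant, a masked histogram-style accumulation reads:
+
+    .. highlight:: python
+    .. code-block:: python
+
+        old = tl.atomic_add(hist_ptr + bin_idx, 1, mask=valid)
+
+    where :code:`bin_idx` and :code:`valid` are a previously computed index
+    tensor and boolean mask (hypothetical names).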
+ + :param pointer: The memory locations to operate on + :type pointer: Block of dtype=triton.PointerDType""" + if has_cmp: + docstr += """ + :param cmp: The values expected to be found in the atomic object + :type cmp: Block of dtype=pointer.dtype.element_ty""" + docstr += """ + :param val: The values with which to perform the atomic operation + :type val: Block of dtype=pointer.dtype.element_ty + :param sem: Specifies the memory semantics for the operation. Acceptable values are "acquire", + "release", "acq_rel" (stands for "ACQUIRE_RELEASE"), and "relaxed". If not provided, + the function defaults to using "acq_rel" semantics. + :type sem: str, optional + :param scope: Defines the scope of threads that observe the synchronizing effect of the atomic operation. + Acceptable values are "gpu" (default), "cta" (cooperative thread array, thread block), or "sys" (stands for "SYSTEM"). The default value is "gpu". + :type scope: str, optional + """ + func.__doc__ = docstr + return func + + return _decorator + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("compare-and-swap", has_cmp=True) +def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None): + cmp = semantic.to_tensor(cmp, _builder) + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("exchange") +def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("add") +def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_add(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("max") +def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_max(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("min") +def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_min(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical and") +def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_and(pointer, val, mask, sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical or") +def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_or(pointer, val, mask, 
sem, scope, _builder) + + +@_tensor_member_fn +@builtin +@_add_atomic_docstr("logical xor") +def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None): + val = semantic.to_tensor(val, _builder) + sem = _constexpr_to_value(sem) + scope = _constexpr_to_value(scope) + mask = _constexpr_to_value(mask) + return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder) + + +# ----------------------- +# Conditioning +# ----------------------- + + +@builtin +def where(condition, x, y, _builder=None): + """ + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + + Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. + + If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead. + + The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`. + :code:`x` and :code:`y` must have the same data type. + + :param condition: When True (nonzero), yield x, otherwise yield y. + :type condition: Block of triton.bool + :param x: values selected at indices where condition is True. + :param y: values selected at indices where condition is False. + """ + condition = semantic.to_tensor(condition, _builder) + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return semantic.where(condition, x, y, _builder) + + +# ----------------------- +# Math +# ----------------------- + + +@builtin +def add(x, y, sanitize_overflow: constexpr = True, _builder=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return semantic.add(x, y, sanitize_overflow, _builder) + + +@builtin +def sub(x, y, sanitize_overflow: constexpr = True, _builder=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return semantic.sub(x, y, sanitize_overflow, _builder) + + +@builtin +def mul(x, y, sanitize_overflow: constexpr = True, _builder=None): + x = _unwrap_if_constexpr(x) + y = _unwrap_if_constexpr(y) + return semantic.mul(x, y, sanitize_overflow, _builder) + + +@builtin +def minimum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise minimum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.minimum(x, y, propagate_nan, _builder) + + +@builtin +def maximum(x, y, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Computes the element-wise maximum of :code:`x` and :code:`y`. + + :param x: the first input tensor + :type x: Block + :param y: the second input tensor + :type y: Block + :param propagate_nan: whether to propagate NaN values. + :type propagate_nan: tl.PropagateNan + + .. 
seealso:: :class:`tl.PropagateNan` + """ + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + y = _promote_bfloat16_to_float32(y, _builder=_builder) + propagate_nan = _constexpr_to_value(propagate_nan) + return semantic.maximum(x, y, propagate_nan, _builder) + + +@builtin +def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=None): + """ + Clamps the input tensor :code:`x` within the range [min, max]. + Behavior when :code:`min` > :code:`max` is undefined. + + :param x: the input tensor + :type x: Block + :param min: the lower bound for clamping + :type min: Block + :param max: the upper bound for clamping + :type max: Block + :param propagate_nan: whether to propagate NaN values. Applies only to the :code:`x` tensor. + If either :code:`min` or :code:`max` is NaN, the result is undefined. + :type propagate_nan: tl.PropagateNan + + .. seealso:: :class:`tl.PropagateNan` + """ + x = semantic.to_tensor(x, _builder) + min = semantic.to_tensor(min, _builder) + max = semantic.to_tensor(max, _builder) + x = _promote_bfloat16_to_float32(x, _builder=_builder) + min = _promote_bfloat16_to_float32(min, _builder=_builder) + max = _promote_bfloat16_to_float32(max, _builder=_builder) + + propagate_nan = _constexpr_to_value(propagate_nan) + + return semantic.clamp(x, min, max, propagate_nan, _builder) + + +# ----------------------- +# Reductions +# ----------------------- + + +def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :type input: Tensor + :param axis: the dimension along which the reduction should be done. If None, reduce all dimensions + :type axis: int + :param keep_dims: if true, keep the reduced dimensions with length 1 + :type keep_dims: bool""" + if return_indices_arg is not None: + docstr += f""" + :param {return_indices_arg}: if true, return index corresponding to the {name} value + :type {return_indices_arg}: bool""" + if tie_break_arg is not None: + docstr += f""" + :param {tie_break_arg}: if true, in case of a tie (i.e., multiple elements have the same {name} value), return the left-most index for values that aren't NaN + :type {tie_break_arg}: bool""" + + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@contextmanager +def _insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) + + +@_tensor_member_fn +@builtin +def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` + + :param input: the input tensor, or tuple of tensors + :type input: Tensor + :param axis: the dimension along which the reduction should be done. 
If None, reduce all dimensions + :type axis: int | None + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + :type combine_fn: Callable + :param keep_dims: if true, keep the reduced dimensions with length 1 + :type keep_dims: bool + + """ + if isinstance(input, tensor): + return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(reduce_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = reduce_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_reduce_ret(*handles) + + def expand_ndims(t, ndims): + for _ in builtins.range(ndims): + t = expand_dims(t, 0, _builder=_builder) + return t + + axis = _constexpr_to_value(axis) + keep_dims = _constexpr_to_value(keep_dims) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + ret = semantic.reduction(input, axis, make_combine_region, _builder) + if keep_dims: + if axis is not None: + ret = tuple(expand_dims(t, axis, _builder=_builder) for t in ret) + else: + ret = tuple(expand_ndims(t, len(input[0].shape)) for t in ret) + return ret + + +@builtin +def _promote_bfloat16_to_float32(t, _builder=None): + scalar_ty = t.type.scalar + + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is bfloat16: + return t.to(float32, _builder=_builder) + return t + + +@builtin +def _reduce_with_indices(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=None): + axis = _constexpr_to_value(axis) + n = input.shape[axis] + index = arange(0, n, _builder=_builder) + + if len(input.shape) > 1: + # Broadcast index across the non-reduced axes + axes_to_expand = [constexpr(d) for d in builtins.range(len(input.shape))] + del axes_to_expand[axis] + index = expand_dims(index, axes_to_expand, _builder=_builder) + index = broadcast_to(index, input.shape, _builder=_builder) + + rvalue, rindices = reduce((input, index), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, + _generator=_generator) + return rvalue, rindices + + +# ----------------------- +# Scans +# ----------------------- + + +def _add_scan_docstr(name: str) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :type input: Tensor + :param axis: the dimension along which the scan should be done + :type axis: int""" + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@_tensor_member_fn +@builtin +def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _generator=None): + """Applies the combine_fn to each elements with a carry in :code:`input` tensors along the provided :code:`axis` and update the carry + + :param input: the input tensor, or tuple of tensors + :type input: Tensor + :param axis: the dimension along which the reduction should be done + :type axis: int + :param combine_fn: a function to combine two groups of scalar tensors (must be marked 
with @triton.jit) + :type combine_fn: Callable + :param reverse: whether to apply the associative scan in the reverse direction along axis + :type reverse: bool + + """ + if isinstance(input, tensor): + return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0] + + def make_combine_region(scan_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = scan_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_scan_ret(*handles) + + axis = _constexpr_to_value(axis) + if axis is not None: + axis = _wrap_axis(axis, len(input[0].shape)) + return semantic.associative_scan(input, axis, make_combine_region, reverse, _builder) + + +@_tensor_member_fn +@builtin +def histogram(input, num_bins, _builder=None, _generator=None): + """computes an histogram based on input tensor with num_bins bins, the bins have a width of 1 and start at 0. + + :param input: the input tensor + :type input: Tensor + :param num_bins: number of histogram bins + :type num_bins: int + + """ + num_bins = _constexpr_to_value(num_bins) + return semantic.histogram(input, num_bins, _builder) + + +@_tensor_member_fn +@builtin +def gather(src, index, axis, _builder=None): + """Gather from a tensor along a given dimension. + :param src: the source tensor + :type src: Tensor + :param index: the index tensor + :type index: Tensor + :param axis: the dimension to gather along + :type axis: int + """ + axis = _constexpr_to_value(axis) + return semantic.gather(src, index, axis, _builder) + + +# ----------------------- +# Compiler Hint Ops +# ----------------------- + + +@builtin +def debug_barrier(_builder=None): + ''' + Insert a barrier to synchronize all threads in a block. + ''' + return semantic.debug_barrier(_builder) + + +@builtin +def multiple_of(input, values, _builder=None): + """ + Let the compiler know that the values in :code:`input` are all multiples of :code:`value`. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.multiple_of(input, values) + + +@builtin +def max_contiguous(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are contiguous. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_contiguous(input, values) + + +@builtin +def max_constancy(input, values, _builder=None): + """ + Let the compiler know that the `value` first values in :code:`input` are constant. 
+ + e.g. if :code:`values` is [4], then each group of 4 values in :code:`input` should all be equal, + for example [0, 0, 0, 0, 1, 1, 1, 1]. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_constancy(input, values) + + +@builtin +def assume(cond, _builder=None): + ''' + Allow compiler to assume the :code:`cond` is True. + ''' + return semantic.assume(semantic.to_tensor(cond, _builder), _builder) + + +# ----------------------- +# Debugging functions +# ----------------------- + + +@builtin +def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None): + ''' + Print the values at compile time. The parameters are the same as the builtin :code:`print`. + + NOTE: Calling the Python builtin :code:`print` is not the same as calling this, it instead maps to :code:`device_print`, + which has special requirements for the arguments. + + .. highlight:: python + .. code-block:: python + + tl.static_print(f"BLOCK_SIZE={BLOCK_SIZE}") + ''' + pass + + +@builtin +def static_assert(cond, msg="", _builder=None): + ''' + Assert the condition at compile time. Does not require that the :code:`TRITON_DEBUG` environment variable + is set. + + .. highlight:: python + .. code-block:: python + + tl.static_assert(BLOCK_SIZE == 1024) + ''' + pass + + +@builtin +def device_print(prefix, *args, hex=False, _builder=None): + ''' + Print the values at runtime from the device. String formatting does not work for runtime values, so you should + provide the values you want to print as arguments. The first value must be a string, all following values must + be scalars or tensors. + + Calling the Python builtin :code:`print` is the same as calling this function, and the requirements for the arguments will match + this function (not the normal requirements for :code:`print`). + + .. highlight:: python + .. code-block:: python + + tl.device_print("pid", pid) + print("pid", pid) + + On CUDA, printfs are streamed through a buffer of limited size (on one host, + we measured the default as 6912 KiB, but this may not be consistent across + GPUs and CUDA versions). If you notice some printfs are being dropped, you + can increase the buffer size by calling + + .. highlight:: python + .. code-block:: python + + triton.runtime.driver.active.utils.set_printf_fifo_size(size_bytes) + + CUDA may raise an error if you try to change this value after running a + kernel that uses printfs. The value set here may only affect the current + device (so if you have multiple GPUs, you'd need to call it multiple times). + + :param prefix: a prefix to print before the values. This is required to be a string literal. + :param args: the values to print. They can be any tensor or scalar. 
+ :param hex: print all values as hex instead of decimal + ''' + import string + prefix = _constexpr_to_value(prefix) + assert isinstance(prefix, str), f"{prefix} is not string" + b_ascii = True + for ch in prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(semantic.to_tensor(arg, _builder)) + return semantic.device_print(prefix, new_args, hex, _builder) + + +@builtin +def device_assert(cond, msg="", _builder=None): + ''' + Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG` + is set to a value besides :code:`0` in order for this to have any effect. + + Using the Python :code:`assert` statement is the same as calling this function, except that the second argument + must be provided and must be a string, e.g. :code:`assert pid == 0, "pid != 0"`. The environment variable must + be set for this :code:`assert` statement to have any effect. + + .. highlight:: python + .. code-block:: python + + tl.device_assert(pid == 0) + assert pid == 0, f"pid != 0" + + :param cond: the condition to assert. This is required to be a boolean tensor. + :param msg: the message to print if the assertion fails. This is required to be a string literal. + ''' + msg = _constexpr_to_value(msg) + return semantic.device_assert(semantic.to_tensor(cond, _builder), msg, _builder) + + +@builtin +def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Union[dtype, Sequence[dtype]], + is_pure: bool, pack: int, _builder=None): + ''' + Execute inline assembly over a tensor. Essentially, this is :code:`map` + where the function is inline assembly. + + The input tensors :code:`args` are implicitly broadcasted to the same shape. + + :code:`dtype` can be a tuple of types, in which case the output is a + tuple of tensors. + + Each invocation of the inline asm processes :code:`pack` elements at a + time. Exactly which set of inputs a block receives is unspecified. + Input elements of size less than 4 bytes are packed into 4-byte + registers. + + This op does not support empty :code:`dtype` -- the inline asm must + return at least one tensor, even if you don't need it. You can work + around this by returning a dummy tensor of arbitrary type; it shouldn't + cost you anything if you don't use it. + + Example using + `PTX `_ + assembly: + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(A, B, C, D, BLOCK: tl.constexpr): + a = tl.load(A + tl.arange(0, BLOCK)) # uint8 tensor + b = tl.load(B + tl.arange(0, BLOCK)) # float32 tensor + + # For each (a,b) in zip(a,b), perform the following: + # - Let ai be `a` converted to int32. + # - Let af be `a` converted to float. + # - Let m be the max of ai and b. + # - Return ai and mi. + # Do the above 4 elements at a time. + (c, d) = tl.inline_asm_elementwise( + asm=""" + { + // Unpack `a` into `ai`. + .reg .b8 tmp<4>; + mov.b32 {tmp0, tmp1, tmp2, tmp3}, $8; + cvt.u32.u8 $0, tmp0; + cvt.u32.u8 $1, tmp1; + cvt.u32.u8 $2, tmp2; + cvt.u32.u8 $3, tmp3; + } + // Convert `ai` to float. + cvt.rn.f32.s32 $4, $0; + cvt.rn.f32.s32 $5, $1; + cvt.rn.f32.s32 $6, $2; + cvt.rn.f32.s32 $7, $3; + // Take max of `ai` and `b`. + max.f32 $4, $4, $9; + max.f32 $5, $5, $10; + max.f32 $6, $6, $11; + max.f32 $7, $7, $12; + """, + constraints=( + # 8 output registers, namely + # $0=ai0, $1=ai1, $2=ai2, $3=ai3, + # $4=m0, $5=m1, $6=m2, $7=m3. 
+ "=r,=r,=r,=r,=r,=r,=r,=r," + # 5 input registers, namely + # $8=ai, + # $9=b0, $10=b1, $11=b2, $12=b3. + # The four elements from `a` are all packed into one register. + "r,r,r,r,r"), + args=[a, b], + dtype=(tl.int32, tl.float32), + is_pure=True, + pack=4, + ) + tl.store(C + tl.arange(0, BLOCK), c) + tl.store(D + tl.arange(0, BLOCK), d) + + :param asm: assembly to run. Must match target's assembly format. + :param constraints: asm constraints in + `LLVM format `_ + :param args: the input tensors, whose values are passed to the asm block + :param dtype: the element type(s) of the returned tensor(s) + :param is_pure: if true, the compiler assumes the asm block has no side-effects + :param pack: the number of elements to be processed by one instance of inline assembly + :param _builder: the builder + :return: one tensor or a tuple of tensors of the given dtypes + ''' + asm = _constexpr_to_value(asm) + constraints = _constexpr_to_value(constraints) + pack = _constexpr_to_value(pack) + is_pure = _constexpr_to_value(is_pure) + + # Wrap `dtype` in a tuple if it's not already. + try: + iter(dtype) # type: ignore + has_multiple_outputs = True + except TypeError: + has_multiple_outputs = False + dtype = (dtype, ) # type: ignore + + dtype = typing.cast(Sequence[_DtypeClass], dtype) + + res_tys = dtype + if dispatch_args := [semantic.to_tensor(arg, _builder) for arg in args]: + bin_op_type_checking = partial( + semantic.binary_op_type_checking_impl, + builder=_builder, + arithmetic_check=False, + allow_lhs_ptr=True, + allow_rhs_ptr=True, + ) + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = bin_op_type_checking(item, broadcast_arg) + if broadcast_arg.shape: + # Change the shape of each argument based on the broadcast shape + for i, item in enumerate(dispatch_args): + dispatch_args[i], _ = bin_op_type_checking(item, broadcast_arg) + res_tys = [block_type(dt, broadcast_arg.shape) for dt in dtype] + handles = [t.handle for t in dispatch_args] + call = _builder.create_inline_asm(asm, constraints, handles, [ty.to_ir(_builder) for ty in res_tys], is_pure, pack) + + if not has_multiple_outputs: + return tensor(call.get_result(0), res_tys[0]) + return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(res_tys)) + + +# ----------------------- +# Iterators +# ----------------------- + + +class static_range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.static_range(10): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. 
+ """ + + def __init__(self, arg1, arg2=None, step=None): + assert isinstance(arg1, constexpr), f"{arg1} used as tl.static_range start value is not a constexpr" + if step is None: + self.step = constexpr(1) + else: + assert isinstance(step, constexpr), f"{step} used as tl.static_range step value is not a constexpr" + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 else: - raise ValueError(f"Unknown overflow_mode:{overflow_mode} is found.") + assert isinstance(arg2, constexpr), f"{arg2} used as tl.static_range end value is not a constexpr" + self.start = arg1 + self.end = arg2 + + def __iter__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + +class range: + """ + Iterator that counts upward forever. + + .. highlight:: python + .. code-block:: python + + @triton.jit + def kernel(...): + for i in tl.range(10, num_stages=3): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + :param num_stages: pipeline the loop into this many stages (so there are + :code:`num_stages` iterations of the loop in flight at once). + Note this is subtly different than passing :code:`num_stages` as a + kernel argument. The kernel argument only pipelines loads that feed + into :code:`dot` operations, while this attribute tries to pipeline most + (though not all) loads in this loop. + :param loop_unroll_factor: Tells the Triton IR level loop unroller how many + times to unroll a for loop that this range is used with. Less than 2 for + this value implies no unrolling. + :param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot + operation in the loop to be multi-buffered, if applicable. + :param flatten: automatically flatten the loop nest starting at this loop to + create a single flattened loop. The compiler will try to pipeline the + flattened loop which can avoid stage stalling. + :param warp_specialize: Enable automatic warp specialization on the loop. + The compiler will attempt to partition memory, MMA, and vector + operations in the loop into separate async partitions. This will + increase the total number of warps required by the kernel. + :param disable_licm: Tells the compiler it shouldn't hoist loop invariant + code outside the loop. This is often useful to avoid creating long liveranges + within a loop. -def ext_trans_unwrap_iterable(dims): - return _unwrap_iterable(dims) + Note that warp specialization is only supported on Blackwell GPUs and + only works on simple matmul loops. Support for arbitrary loops will be + expanded over time. 
+ """ + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, + disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False): + if step is None: + self.step = constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + self.num_stages = num_stages + self.loop_unroll_factor = loop_unroll_factor + self.disallow_acc_multi_buffer = disallow_acc_multi_buffer + self.flatten = flatten + self.warp_specialize = warp_specialize + self.disable_licm = disable_licm -def check_dot_deprecated_param_allow_tf32(allow_tf32): - assert (not allow_tf32), "allow_tf32 is deprecated, please use input_precision='hf32' on Ascend instead." + def __iter__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") + def __next__(self): + raise RuntimeError("tl.range can only be used in @triton.jit'd functions") -def check_dot_invalid_input_precision(input_precision): - assert input_precision not in [ - "tf32", - "tf32x3", - ], "input_precision == tf32 or tf32x3 is invalid, please use input_precision='hf32' on Ascend instead." +# ----------------------- +# Extern functions +# ----------------------- + + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, + is_pure: bool, _builder=None): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _builder: the builder + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) -@_tensor_member_fn -@builtin -def gather(src, index, axis, _builder=None): - """Gather from a tensor along a given dimension. - :param src: the source tensor - :type src: Tensor - :param index: the index tensor - :type index: Tensor - :param axis: the dimension to gather along - :type axis: int - """ - axis = _constexpr_to_value(axis) - return semantic_spec.ext_semantic_gather(src, index, axis, _builder) + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." 
+ f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + if ret_shape: + ret_type = block_type(ret_type, ret_shape) + return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type) -@_tensor_member_fn @builtin -def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: +def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _builder=None): + ''' + Dispatch an elementwise function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param is_pure: whether the function is pure + :param _builder: the builder + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + ret_shape = None + arg_types = [] + for i in builtins.range(len(dispatch_args)): + dispatch_args[i] = semantic.to_tensor(dispatch_args[i], _builder) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + if len(arg_types) > 0: + arg_types = tuple(arg_types) + arithmetic_check = True + # If there's a type tuple that is not supported by the library, we will do arithmetic check + if arg_types in arg_type_symbol_dict: + arithmetic_check = False + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for item in dispatch_args: + _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + # Change the shape of each argument based on the broadcast shape + for i in builtins.range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder, + arithmetic_check=arithmetic_check) + if not all_scalar: + ret_shape = broadcast_arg.shape + func = _builder.create_extern_elementwise + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder) + + +def binary_op_type_legalization(lhs, rhs, builder): + ''' + Convert both operands to a single common type + :param lhs: the left operand + :param rhs: the right operand + :param builder: the builder + ''' + return semantic.binary_op_type_checking_impl(lhs, rhs, builder) + + +def extern(fn): + """A decorator for external functions.""" + return builtin(fn) + + +# ------------------- FIXME:when upgrade to 3.4, del these ------------------- +def _unwrap_if_constexpr_3_4(o): + if isinstance(o, list): + return [_unwrap_if_constexpr_3_4(x) for x in o] + if isinstance(o, builtins.tuple): + return builtins.tuple(_unwrap_if_constexpr_3_4(x) for x in o) + if isinstance(o, tuple): + return tuple(_unwrap_if_constexpr_3_4(x) for x in o) + return o.value if isinstance(o, constexpr) else o + + +def _unwrap_shape_3_4(shape): + shape = _unwrap_if_constexpr_3_4(shape) + return [_unwrap_if_constexpr_3_4(s) for s in shape] + + +def _normalize_tuple(t): + normalized_tuple = _unwrap_if_constexpr_3_4(t) + if isinstance(normalized_tuple, (list, builtins.tuple)): + normalized_tuple = tuple(normalized_tuple) + return normalized_tuple + + +class base_value(_value): + """Base class of values that exist in the triton IR (i.e. not constexprs). 
""" - Insert a tensor to another tensor as specified by the operation’s offsets, sizes and strides arguments. + type: base_type - :param ful: The tensor to receive tensor. - :type ful: Tensor - :param sub: The tensor to be inserted. - :type sub: Tensor - :param offsets: - :type offsets: tuple of ints - :param sizes: - :type sizes: tuple of ints - :param strides: - :type strides: tuple of ints - """ - assert len(ful.shape) > 0 - assert len(ful.shape) == len(sub.shape) - new_offsets = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] - out = semantic_spec.ext_semantic_insert_slice(ful, sub, new_offsets, sizes, strides, _builder) - return out + def _flatten_ir(self, handles: List[ir.value]) -> None: + """Flatten frontend value into a sequence of mlir handles, which are appended + to the output list + """ + raise NotImplementedError -@_tensor_member_fn -@builtin -def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: - """ - Extract a tensor from another tensor as specified by the operation’s offsets, sizes and strides arguments. +class base_type: - :param ful: The tensor to split. - :type ful: Tensor - :param offsets: - :type offsets: tuple of ints - :param sizes: - :type sizes: tuple of ints - :param strides: - :type strides: tuple of ints - """ - assert len(ful.shape) > 0 - new_offsets = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] - sub = semantic_spec.ext_semantic_extract_slice(ful, new_offsets, sizes, strides, _builder) - return sub + def __eq__(self, other): + raise NotImplementedError("Types must implement __eq__") + def __ne__(self, other): + return not (self == other) -@_tensor_member_fn -@builtin -def get_element(src, indice, _builder=None, _generator=None): - """ - get_element op reads a ranked tensor and returns one element as specified by the given indices. - The result of the op is a value with the same type as the elements of the tensor. - The arity of indices must match the rank of the accessed value. + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + """Build a frontend value with the current dtype, wrapping a list of existing handles. + cursor is the index of the first handle relevant to this value, and the function + should return the updated cursor position after any handles consumed by the created value. + """ + raise NotImplementedError - :param src: The tensor to be accessed. 
- :type src: Tensor
- :param indice:
- :type indice: tuple of ints
- """
- assert len(src.shape) > 0
- new_indice = [semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i for i in indice]
- return semantic_spec.ext_semantic_get_element(src, new_indice, _builder)
+ def mangle(self) -> str:
+ raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}")
+ def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
+ raise NotImplementedError
-@builtin
-def __add__(self, other, _builder=None):
- return add(self, other, sanitize_overflow=False, _builder=_builder)
+class tuple(base_value):
-@builtin
-def __radd__(self, other, _builder=None):
- return add(other, self, sanitize_overflow=False, _builder=_builder)
+ def __init__(self, args: Sequence, type: tuple_type = None):
+ self.values = [i for i in args]
+ def get_type(x):
+ if isinstance(x, dtype):
+ return dtype
+ if isinstance(x, (int, float)):
+ return constexpr
+ return x.type
-@builtin
-def __sub__(self, other, _builder=None):
- return sub(self, other, sanitize_overflow=False, _builder=_builder)
+ self.type = type or tuple_type([get_type(x) for x in self.values])
+ def __getitem__(self, idx: constexpr):
+ if isinstance(idx, int):
+ idx = constexpr(idx)
+ if isinstance(idx, constexpr):
+ return self.values[idx]
+ else:
+ assert isinstance(idx, (slice, builtins.slice))
+ return tuple(self.values[idx.start:idx.stop:idx.step])
-@builtin
-def __rsub__(self, other, _builder=None):
- return sub(other, self, sanitize_overflow=False, _builder=_builder)
+ def __getattr__(self, name):
+ return self.values[self.type.fields.index(name)]
+ def __setitem__(self, idx: constexpr, value):
+ if isinstance(idx, int):
+ idx = constexpr(idx)
+ assert isinstance(idx, constexpr)
+ self.values[idx] = value
-@builtin
-def __mul__(self, other, _builder=None):
- return mul(self, other, sanitize_overflow=False, _builder=_builder)
+ def __add__(self, other):
+ other = _normalize_tuple(other)
+ return tuple(self.values + other.values)
+ def __mul__(self, other):
+ assert isinstance(other, constexpr)
+ return tuple(self.values * other.value)
-@builtin
-def __rmul__(self, other, _builder=None):
- return mul(other, self, sanitize_overflow=False, _builder=_builder)
+ def __eq__(self, other):
+ other = _normalize_tuple(other)
+ return constexpr(self.values == other.values)
+ def __hash__(self):
+ return hash(builtins.tuple(self.values))
-@builtin
-def __mod__(self, other, _builder=None):
- other = _unwrap_if_constexpr(other)
- return semantic.mod(self, other, _builder)
+ def __str__(self):
+ return str([str(x) for x in self.values])
+ def __iter__(self):
+ return iter(self.values)
-@builtin
-def __lshift__(self, other, _builder=None):
- if self.type.scalar.is_floating():
- raise TypeError(f"unexpected type {self.type.scalar}")
- check_bit_width(self, other)
- other = _unwrap_if_constexpr(other)
- return semantic.shl(self, other, _builder)
+ def __len__(self):
+ return len(self.values)
+ def _flatten_ir(self, handles: List[ir.value]):
+ for v in self.values:
+ v._flatten_ir(handles)
-@builtin
-def __rshift__(self, other, _builder=None):
- if self.type.scalar.is_floating():
- raise TypeError(f"unexpected type {self.type.scalar}")
- other = _unwrap_if_constexpr(other)
- check_bit_width(self, other)
- if self.dtype.is_int_signed():
- return semantic.ashr(self, other, _builder)
- else:
- return semantic.lshr(self, other, _builder)
+ 
def __repr__(self): + return f"({' ,'.join(repr(x) for x in self.values)})" -@builtin -def flip(ptr, dim=-1, _builder=None, _generator=None): - try: - dim = int(dim.value) if hasattr(dim, "value") else int(dim) - except Exception as e: - raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}") from e +class tuple_type(base_type): - dim = len(ptr.shape) - 1 if dim == -1 else dim - return semantic_spec.ext_semantic_flip(ptr, dim, _builder, _generator) + def __init__(self, types, fields=None): + self.types = types + self.fields = fields or [''] * len(types) + self.name = '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']' + def __str__(self): + return self.name -@builtin -def compile_hint(ptr, hint_name, hint_val=None, _builder=None): + def __iter__(self): + return iter(self.types) + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]): + for ty in self.types: + if not isinstance(ty, constexpr): + ty._flatten_ir_types(builder, out) + + def __getitem__(self, index: int) -> dtype: + return self.types[index] + + def __eq__(self, other): + return type(self) is type(other) and self.types == other.types and self.fields == other.fields + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]: + values = [] + for ty in self.types: + value, cursor = ty._unflatten_ir(handles, cursor) + values.append(value) + return tuple(values, self), cursor + + def mangle(self): + return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T' + + +class dtype_3_4(base_type): + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + class KIND(Enum): + BOOLEAN = 0 + INTEGRAL = 1 + FLOATING = 2 + + def __init__(self, name): + name = _unwrap_if_constexpr(name) + self.name = name + assert name in dtype_3_4.SINT_TYPES + dtype_3_4.UINT_TYPES + dtype_3_4.FP_TYPES + dtype_3_4.OTHER_TYPES, name + self.primitive_bitwidth = get_primitive_bitwidth(name) + self.itemsize = self.primitive_bitwidth // 8 + if name in dtype_3_4.SINT_TYPES: + self.int_signedness = dtype_3_4.SIGNEDNESS.SIGNED + self.int_bitwidth = self.primitive_bitwidth + elif name in dtype_3_4.UINT_TYPES: + self.int_signedness = dtype_3_4.SIGNEDNESS.UNSIGNED + self.int_bitwidth = self.primitive_bitwidth + elif name in dtype_3_4.FP_TYPES: + if name == 'fp8e4b15': + self.fp_mantissa_width = 3 + self.exponent_bias = 15 + elif name == 'fp8e4nv': + self.fp_mantissa_width = 3 + self.exponent_bias = 7 + elif name == 'fp8e4b8': + self.fp_mantissa_width = 3 + self.exponent_bias = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.exponent_bias = 15 + elif name == 'fp8e5b16': + self.fp_mantissa_width = 2 + self.exponent_bias = 16 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.exponent_bias = 15 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.exponent_bias = 127 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.exponent_bias = 127 + elif name == 'fp64': + self.fp_mantissa_width = 52 + self.exponent_bias = 1023 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp8e4nv(self): + return self.name == 'fp8e4nv' + + def is_fp8e4b8(self): + return self.name == 'fp8e4b8' + + def 
is_fp8e4b15(self): + return self.name == 'fp8e4b15' + + def is_fp8e5(self): + return self.name == 'fp8e5' + + def is_fp8e5b16(self): + return self.name == 'fp8e5b16' + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype_3_4.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype_3_4.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype_3_4.SINT_TYPES + + def is_int_unsigned(self): + return self.name in dtype_3_4.UINT_TYPES + + def is_int(self): + return self.name in dtype_3_4.SINT_TYPES + dtype_3_4.UINT_TYPES + + def is_bool(self): + return self.is_int1() - def _unwrap(val): - return _unwrap_if_constexpr(val) if val else val + def kind(self): + # Return int value following the type ordering bool < integer < fp + if self.is_bool(): + return dtype_3_4.KIND.BOOLEAN + elif self.is_int(): + return dtype_3_4.KIND.INTEGRAL + else: + assert self.is_floating() + return dtype_3_4.KIND.FLOATING + + def get_int_max_value(self): + if self.is_int_signed(): + return 2**(self.int_bitwidth - 1) - 1 + if self.is_int_unsigned(): + return 2**self.int_bitwidth - 1 + assert False + + def get_int_min_value(self): + if self.is_int_signed(): + return -2**(self.int_bitwidth - 1) + if self.is_int_unsigned(): + return 0 + assert False + + @staticmethod + def is_dtype(type_str): + return type_str in dtype_3_4.SINT_TYPES + dtype_3_4.UINT_TYPES + dtype_3_4.FP_TYPES + dtype_3_4.OTHER_TYPES + + @staticmethod + def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + @staticmethod + def is_const(): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __hash__(self): + return hash((self.name, )) + + @property + def scalar(self): + return self + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + out.append(self.to_ir(builder)) + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name.startswith("fp8"): + if self.name not in builder.options.supported_fp8_dtypes: + raise ValueError(f'type {self} not supported in this architecture. 
' + f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}') + + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e5b16': + return builder.get_fp8e5b16_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b8': + return builder.get_fp8e4b8_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + def codegen_name(self): + if self.name.startswith("fp"): + return "float" + self.name[2:] + elif self.name.startswith("bf"): + return "bfloat" + self.name[2:] + else: + return self.name - hint_name = _constexpr_to_value(hint_name) - assert isinstance(hint_name, str), f"hint name: {hint_name} is not string" - if isinstance(hint_val, list): - hint_val = [_unwrap(val) for val in hint_val] - else: - hint_val = _unwrap(hint_val) - hint_val = _unwrap_if_constexpr(hint_val) if hint_val else hint_val - semantic_spec.ext_semantic_compile_hint(ptr, hint_name, hint_val, _builder) + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + def __repr__(self): + """Output of repr needs to be an evaluatable expression""" + return f'triton.language.{self.codegen_name()}' -@builtin -def sort(ptr, dim=-1, descending=False, _builder=None): - """ - Triton sort 前端接口 + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: + return tensor(handles[cursor], self), cursor + 1 - 参数: - ptr: tl.tensor,输入张量 - dim: int 或 tl.constexpr[int],排序维度 - descending: bool 或 tl.constexpr[bool],是否降序 - _builder: ir.builder,底层 IR 构建器 - 返回: - values: tl.tensor,排序后的值(类型与输入一致) - """ + def mangle(self) -> str: + if self.is_int(): + SIGNED = dtype_3_4.SIGNEDNESS.SIGNED + prefix = 'i' if self.int_signedness == SIGNED else 'u' + return prefix + str(self.int_bitwidth) + if self.is_floating(): + return str(self) + if self.is_void(): + return 'V' + return super().mangle() - try: - dim = int(dim.value) if hasattr(dim, "value") else int(dim) - except Exception as e: - raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}. 
Error: {str(e)}") from e + def with_element_ty(self, element_ty: dtype): + assert not self.is_block() + return element_ty - if hasattr(descending, "value"): - descending = bool(descending.value) - else: - descending = bool(descending) - ret = semantic_spec.ext_semantic_sort(ptr, dim, descending, _builder) - base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type - if base_ty.is_int8() or base_ty.is_int16(): - semantic_spec.ext_semantic_compile_hint(ret, "overflow_mode", constexpr("saturate"), _builder) - return ret +class block_type_3_4(dtype_3_4): + def __init__(self, element_ty: dtype_3_4, shape: List): + self.element_ty = element_ty -@builtin -def multibuffer(src: tensor, size, _builder=None): - """ - Set multi_buffer for an existing tensor - :src: tensor set to bufferize multiple time - :size: number of copies + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + assert (isinstance(shape, (list, tuple))) + + # shape can be empty ([]) when an input is a 0D tensor. + self.shape = tuple(_unwrap_shape_3_4(shape)) + if not self.shape: + raise TypeError('0d block_type is forbidden') + + self.numel = validate_block_shape(self.shape) + self.name = f'<{self.shape}, {self.element_ty}>' + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> Tuple[int]: + return self.shape + + def with_element_ty(self, scalar_ty: dtype_3_4) -> block_type_3_4: + return block_type_3_4(scalar_ty, self.shape) + + def __eq__(self, other) -> bool: + if not isinstance(other, block_type_3_4): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + @property + def scalar(self): + return self.element_ty + + def mangle(self) -> str: + elt = self.scalar.mangle() + shape = '_'.join(map(str, self.shape)) + return f'{elt}S{shape}S' + + +class tensor_descriptor_base_type(base_type): + + def __init__(self, block_type: block_type): + self.block_type = block_type + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: + value = tensor_descriptor_base(handles[cursor], self.block_type) + return value, cursor + 1 + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + is_signed = self.block_type.element_ty.is_int_signed() + out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder), is_signed)) + + def __str__(self) -> str: + # ex. 
"tensor_descriptor" + return f"tensor_descriptor<{self.block_type}>" + + def __eq__(self, other) -> bool: + if type(other) is not type(self): + return False + return self.block_type == other.block_type + + def __neq__(self, other) -> bool: + return not (self == other) + + def mangle(self) -> str: + return f"TD{self.block_type.mangle()}" + + +class tensor_descriptor_base(base_value): + """" + A tensor descriptor with unknown shape and strides """ - buffer_size = _constexpr_to_value(size) - assert isinstance(buffer_size, int) and buffer_size == 2, f"only support bufferize equals 2" - semantic_spec.ext_semantic_compile_hint(src, "multi_buffer", buffer_size, _builder) + def __init__(self, handle, block_type: block_type_3_4): + """Not called by user code.""" + super().__init__(handle) -@builtin -def sync_block_all(mode, event_id, _builder=None): - mode = _constexpr_to_value(mode) - event_id = _constexpr_to_value(event_id) - assert isinstance(mode, str), f"mode: {mode} is not string" - assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" - assert mode == "all_cube" or mode == "all_vector" or mode == "all", f"ERROR: mode = {mode}, only supports all_cube/all_vector/all" - semantic_spec.ext_semantic_custom_op(_builder, "sync_block_all", mode=mode, event_id=event_id) + self.handle = handle # IR handle + self.type = tensor_descriptor_base_type(block_type) # Tensor type (block_type) + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) -@builtin -def sync_block_set(sender, receiver, event_id, _builder=None): - sender = _constexpr_to_value(sender) - receiver = _constexpr_to_value(receiver) - event_id = _constexpr_to_value(event_id) - assert isinstance(sender, str) and (sender == "cube" - or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" - assert isinstance(receiver, str) and (receiver == "cube" or receiver - == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" - assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" - if sender == receiver: - raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') - semantic_spec.ext_semantic_custom_op(_builder, "sync_block_set", sender=sender, event_id=event_id) + @property + def block_type(self): + return self.type.block_type + @property + def block_shape(self): + return self.type.block_type.shape -@builtin -def sync_block_wait(sender, receiver, event_id, _builder=None): - sender = _constexpr_to_value(sender) - receiver = _constexpr_to_value(receiver) - event_id = _constexpr_to_value(event_id) - assert isinstance(sender, str) and (sender == "cube" - or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" - assert isinstance(receiver, str) and (receiver == "cube" or receiver - == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" - assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" - if sender == receiver: - raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') - semantic_spec.ext_semantic_custom_op(_builder, "sync_block_wait", sender=sender, event_id=event_id) + @property + def dtype(self): + return self.type.block_type.element_ty + + def __str__(self) -> str: + return str(self.type) + + @builtin + def load(self, offsets: Sequence[constexpr | tensor], 
_builder=None) -> tensor: + """Load a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be filled with zeros. + + :note: Offset must be a multiple of 16-bytes + """ + return semantic.descriptor_load(self, offsets, "", "", _builder) + + @builtin + def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _builder=None) -> tensor: + """Store a block from the descriptor starting at the given element offsets. + + Values outside of the tensor bounds will be ignored. + + :note: Offset must be a multiple of 16-bytes + """ + return semantic.descriptor_store(self, value, offsets, _builder) + + +class tensor_descriptor_type(tensor_descriptor_base_type): + + def __init__(self, block_type: block_type_3_4, shape_type: tuple_type, strides_type: tuple_type): + self.block_type = block_type + self.shape_type = shape_type + self.strides_type = strides_type + + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: + handle = handles[cursor] + cursor += 1 + shape, cursor = self.shape_type._unflatten_ir(handles, cursor) + strides, cursor = self.strides_type._unflatten_ir(handles, cursor) + shape = shape.values + strides = strides.values + value = tensor_descriptor(handle, shape, strides, self.block_type) + return value, cursor + + def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: + super()._flatten_ir_types(builder, out) + self.shape_type._flatten_ir_types(builder, out) + self.strides_type._flatten_ir_types(builder, out) + + def __eq__(self, other): + return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type + == other.strides_type) + + +class tensor_descriptor(tensor_descriptor_base): + """A descriptor representing a tensor in global memory. 
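+
+    Instances are normally obtained from :code:`tl.make_tensor_descriptor` rather than
+    constructed directly; a minimal illustrative sketch:
+
+    .. highlight:: python
+    .. code-block:: python
+
+        desc = tl.make_tensor_descriptor(base, shape=(M, N), strides=(N, 1),
+                                         block_shape=(BLOCK_M, BLOCK_N))
+        tile = desc.load((off_m, off_n))
+        desc.store((off_m, off_n), tile)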
+ """ + + def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type_3_4): + """Not called by user code.""" + # IR handle + super().__init__(handle, block_type) + # Global shape + self.shape = tuple(shape) + self.strides = tuple(strides) + self.type = tensor_descriptor_type( + block_type, + shape_type=self.shape.type, + strides_type=self.strides.type, + ) + + def _flatten_ir(self, handles: List[ir.value]) -> None: + handles.append(self.handle) + self.shape._flatten_ir(handles) + self.strides._flatten_ir(handles) @builtin @@ -372,178 +3424,4 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK) """ - return semantic_spec.ext_semantic_make_tensor_descriptor(base, shape, strides, block_shape, _builder) - - -@builtin -def index_select(src: tensor, idx: tensor, bound, lstdim_blksiz, offsets, numels, _builder=None): - """ - Embedding - :src_ptr: - :idx: - """ - bound = _constexpr_to_value(bound) - lstdim_blksiz = _constexpr_to_value(lstdim_blksiz) - return semantic_spec.ext_semantic_embedding_gather(src, idx, bound, lstdim_blksiz, offsets, numels, _builder) - - -def dtype_to_ir(self, builder: ir.builder) -> ir.type: - if not is_compile_on_910_95: - if self.name.startswith("fp8"): - raise ValueError(f'unexpected type fp8.') - - if self.name == 'void': - return builder.get_void_ty() - elif self.name == 'int1': - return builder.get_int1_ty() - elif self.name in ('int8', 'uint8'): - return builder.get_int8_ty() - elif self.name in ('int16', 'uint16'): - return builder.get_int16_ty() - elif self.name in ('int32', 'uint32'): - return builder.get_int32_ty() - elif self.name in ('int64', 'uint64'): - return builder.get_int64_ty() - elif self.name == 'fp8e5': - return builder.get_fp8e5_ty() - elif self.name == 'fp8e5b16': - return builder.get_fp8e5b16_ty() - elif self.name == 'fp8e4nv': - return builder.get_fp8e4nv_ty() - elif self.name == 'fp8e4b8': - return builder.get_fp8e4b8_ty() - elif self.name == 'fp8e4b15': - return builder.get_fp8e4b15_ty() - elif self.name == 'fp16': - return builder.get_half_ty() - elif self.name == 'bf16': - return builder.get_bf16_ty() - elif self.name == 'fp32': - return builder.get_float_ty() - elif self.name == 'fp64': - return builder.get_double_ty() - raise ValueError(f'fail to convert {self} to ir type') - - -@builtin -def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, out_dtype=float32, _builder=None, - lhs_k_pack=True, rhs_k_pack=True): - """ - Returns the matrix product of two blocks in microscaling format. - lhs and rhs use microscaling formats described here: - https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf - :param lhs: The first tensor to be multiplied. - :type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format. - :param lhs_scale: Scale factor for lhs tensor. - :type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). - :param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. - :param rhs: The second tensor to be multiplied. - :type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format. - :param rhs_scale: Scale factor for rhs tensor. - :type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor). - :param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code: `e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}. - :param acc: The accumulator tensor. 
If not None, the result is added to this tensor. - """ - out_dtype = _constexpr_to_value(out_dtype) - assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment" - return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, out_dtype, _builder, - lhs_k_pack, rhs_k_pack) - - -class range(): - """ - Iterator that counts upward forever. - - .. highlight:: python - .. code-block:: python - - @triton.jit - def kernel(...): - for i in tl.range(10, num_stages=3): - ... - :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of - :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. - :param arg1: the start value. - :param arg2: the end value. - :param step: the step value. - :param num_stages: pipeline the loop into this many stages (so there are - :code:`num_stages` iterations of the loop in flight at once). - - Note this is subtly different than passing :code:`num_stages` as a - kernel argument. The kernel argument only pipelines loads that feed - into :code:`dot` operations, while this attribute tries to pipeline most - (though not all) loads in this loop. - :param loop_unroll_factor: Tells the Triton IR level loop unroller how many - times to unroll a for loop that this range is used with. Less than 2 for - this value implies no unrolling. - :param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot - operation in the loop to be multi-buffered, if applicable. - :param flatten: automatically flatten the loop nest starting at this loop to - create a single flattened loop. The compiler will try to pipeline the - flattened loop which can avoid stage stalling. - :param warp_specialize: Enable automatic warp specialization on the loop. - The compiler will attempt to partition memory, MMA, and vector - operations in the loop into separate async partitions. This will - increase the total number of warps required by the kernel. - :param disable_licm: Tells the compiler it shouldn't hoist loop invariant - code outside the loop. This is often useful to avoid creating long liveranges - within a loop. - - Note that warp specialization is only supported on Blackwell GPUs and - only works on simple matmul loops. Support for arbitrary loops will be - expanded over time. - """ - - def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, - disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False): - if step is None: - self.step = constexpr(1) - else: - self.step = step - if arg2 is None: - self.start = constexpr(0) - self.end = arg1 - else: - self.start = arg1 - self.end = arg2 - self.num_stages = num_stages - self.loop_unroll_factor = loop_unroll_factor - self.disallow_acc_multi_buffer = disallow_acc_multi_buffer - self.flatten = flatten - self.warp_specialize = warp_specialize - self.disable_licm = disable_licm - - def __iter__(self): - raise RuntimeError("tl.range can only be used in @triton.jit'd functions") - - def __next__(self): - raise RuntimeError("tl.range can only be used in @triton.jit'd functions") - - -class parallel(range): - """ - Iterator that counts upward forever, with parallel execution semantics. - - This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of - :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. 
- :param bind_sub_block: Tells the compiler if multiple vector cores participate in the loop. - This is used in the mixed cube-vector kernel on 910B. The number of vector cores is determined by the number of - iteration in this loop. Currently on 910B, max 2 vector cores could be used. - """ - - def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, - bind_sub_block: bool = False): - super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) - self.bind_sub_block = bind_sub_block - - -core_ext_spec_api_list = [ - "gather", "insert_slice", "extract_slice", "get_element", "__add__", "__radd__", "__sub__", "__rsub__", "__mul__", - "__rmul__", "__mod__", "__lshift__", "__rshift__", "compile_hint", "sort", "multibuffer", "sync_block_all", - "sync_block_set", "sync_block_wait", "load_tensor_descriptor", "store_tensor_descriptor", "make_tensor_descriptor", - "dtype_to_ir", "parallel", "index_select", "dot_scaled", "range" -] - -core_tensor_ext_spec_api_list = [ - "__add__", "__radd__", "__sub__", "__rsub__", "__mul__", "__rmul__", "__mod__", "__lshift__", "__rshift__" -] + return semantic.make_tensor_descriptor(base, shape, strides, block_shape, _builder) diff --git a/third_party/ascend/backend/spec/triton/language/math.py b/third_party/ascend/backend/spec/triton/language/math.py index 783ab2dc1..a4899cb45 100644 --- a/third_party/ascend/backend/spec/triton/language/math.py +++ b/third_party/ascend/backend/spec/triton/language/math.py @@ -1,55 +1,301 @@ -import triton.language as language -from . import standard from . import core +from . import semantic +from functools import wraps +from typing import List +import numbers -softmax = standard.softmax -sigmoid = standard.sigmoid -argmax = standard.argmax -argmin = standard.argmin -umulhi = language.extra.ascend.libdevice.umulhi -exp = language.extra.ascend.libdevice.exp -exp2 = language.extra.ascend.libdevice.exp2 -log = language.extra.ascend.libdevice.log -log2 = language.extra.ascend.libdevice.log2 -cos = language.extra.ascend.libdevice.cos -sin = language.extra.ascend.libdevice.sin -sqrt = language.extra.ascend.libdevice.sqrt -sqrt_rn = language.extra.ascend.libdevice.sqrt_rn -rsqrt = language.extra.ascend.libdevice.rsqrt -div_rn = language.extra.ascend.libdevice.div_rn -erf = language.extra.ascend.libdevice.erf -tanh = language.extra.ascend.libdevice.tanh -floor = language.extra.ascend.libdevice.floor -ceil = language.extra.ascend.libdevice.ceil -fma = language.extra.ascend.libdevice.fma -_check_dtype = language.extra.ascend.libdevice._check_dtype -cdiv = language.extra.ascend.libdevice.cdiv - -isnan = language.extra.ascend.libdevice.isnan -isinf = language.extra.ascend.libdevice.isinf -reciprocal = language.extra.ascend.libdevice.reciprocal -relu = language.extra.ascend.libdevice.relu -log1p = language.extra.ascend.libdevice.log1p -tan = language.extra.ascend.libdevice.tan -atan = language.extra.ascend.libdevice.atan -ilogb = language.extra.ascend.libdevice.ilogb -ldexp = language.extra.ascend.libdevice.ldexp -pow = language.extra.ascend.libdevice.pow -flip = core.flip -atan2 = standard.atan2 -rint = standard.rint -finitef = standard.finitef -isfinited = standard.isfinited -div_rz = language.extra.ascend.libdevice.div_rz -fmod = language.extra.ascend.libdevice.fmod -trunc = language.extra.ascend.libdevice.trunc -round = language.extra.ascend.libdevice.round - -math_ext_base_api_list = [ - "umulhi", "exp", "exp2", "log", "log2", "cos", "sin", "sqrt", "sqrt_rn", "rsqrt", "div_rn", "erf", "tanh", 
"floor", - "ceil", "fma", "_check_dtype", "softmax", "sigmoid", "cdiv", "argmax", "argmin" -] -math_ext_spec_api_list = [ - "isnan", "isinf", "reciprocal", "relu", "log1p", "tan", "atan", "ilogb", "ldexp", "pow", "flip", "atan2", "div_rz", - "fmod", "trunc", "round", "rint", "finitef", "isfinited" -] +T = core.TypeVar('T') + + +def _check_dtype(dtypes: List[str]) -> T: + """ + We're following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. + """ + + def wrapper(fn): + + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, core.tensor)]: + arg_type = arg.type.scalar.name + if hasattr(arg, 'was_bool_to_int8') and arg.was_bool_to_int8: + # In Triton, int1 maps to the boolean type + arg_type = 'int1' + if arg_type not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg_type}") + return fn(*args, **kwargs) + + return check + + return wrapper + + +def _add_math_1arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`. + + :param x: the input values + :type x: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_2arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x` and :code:`y`. + + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +def _add_math_3arg_docstr(name: str) -> core.Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`, :code:`y`, and :code:`z`. 
+ + :param x: the input values + :type x: Block + :param y: the input values + :type y: Block + :param z: the input values + :type z: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@core.builtin +@_check_dtype(dtypes=["int32", "int64", "uint32", "uint64"]) +@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") +def umulhi(x, y, _builder=None): + x = semantic.to_tensor(x, _builder) + y = semantic.to_tensor(y, _builder) + x, y = core.binary_op_type_legalization(x, y, _builder) + return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5", "fp64"]) +@_add_math_1arg_docstr("exponential") +@core._tensor_member_fn +def exp(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_exp(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5", "fp64"]) +@_add_math_1arg_docstr("exponential (base 2)") +@core._tensor_member_fn +def exp2(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_exp2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5", "fp64"]) +@_add_math_1arg_docstr("natural logarithm") +@core._tensor_member_fn +def log(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_log(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5", "fp64"]) +@_add_math_1arg_docstr("logarithm (base 2)") +@core._tensor_member_fn +def log2(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_log2(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5", "fp64"]) +@_add_math_1arg_docstr("cosine") +@core._tensor_member_fn +def cos(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_cos(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5", "fp64"]) +@_add_math_1arg_docstr("sine") +@core._tensor_member_fn +def sin(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_sin(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5", "fp64"]) +@_add_math_1arg_docstr("fast square root") +@core._tensor_member_fn +def sqrt(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5", "fp64"]) +@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)") +@core._tensor_member_fn +def sqrt_rn(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) + + +@core.builtin +@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5", "fp64"]) +@_add_math_1arg_docstr("inverse square root") +@core._tensor_member_fn +def rsqrt(x, _builder=None): + x = semantic.to_tensor(x, _builder) + return core.tensor(_builder.create_rsqrt(x.handle), x.type) + + +@core.builtin +@_add_math_1arg_docstr("absolute value") +@core._tensor_member_fn +def abs(x, _builder=None): + x = semantic.to_tensor(x, _builder) + dtype = x.dtype + if dtype.is_fp8e4b15(): + mask = core.full(x.shape, 0x7F, core.int8, _builder=_builder) + return 
core.tensor(_builder.create_and(x.handle, mask.handle), x.type)
+ elif dtype.is_floating():
+ return core.tensor(_builder.create_fabs(x.handle), x.type)
+ elif dtype.is_int_signed():
+ return core.tensor(_builder.create_iabs(x.handle), x.type)
+ elif dtype.is_int_unsigned():
+ return x # no-op
+ else:
+ assert False, f"Unexpected dtype {dtype}"
+
+
+@core.builtin
+@_add_math_2arg_docstr("fast division")
+def fdiv(x, y, ieee_rounding=False, _builder=None):
+ ieee_rounding = core._constexpr_to_value(ieee_rounding)
+ x = semantic.to_tensor(x, _builder)
+ y = semantic.to_tensor(y, _builder)
+ return semantic.fdiv(x, y, ieee_rounding, _builder)
+
+
+@core.builtin
+@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5", "fp64"])
+@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)")
+def div_rn(x, y, _builder=None):
+ x = semantic.to_tensor(x, _builder)
+ y = semantic.to_tensor(y, _builder)
+ x, y = core.binary_op_type_legalization(x, y, _builder)
+ return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type)
+
+
+@core.builtin
+@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5", "fp64"])
+@_add_math_1arg_docstr("error function")
+@core._tensor_member_fn
+def erf(x, _builder=None):
+ x = semantic.to_tensor(x, _builder)
+ return core.tensor(_builder.create_erf(x.handle), x.type)
+
+
+@core.builtin
+@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5", "fp64"])
+@_add_math_1arg_docstr("floor")
+@core._tensor_member_fn
+def floor(x, _builder=None):
+ x = semantic.to_tensor(x, _builder)
+ return core.tensor(_builder.create_floor(x.handle), x.type)
+
+
+@core.builtin
+@_add_math_1arg_docstr("ceil")
+@core._tensor_member_fn
+def ceil(x, _builder=None):
+ x = semantic.to_tensor(x, _builder)
+ if x.type.scalar.is_int():
+ return x
+ elif x.type.scalar.is_floating():
+ return core.tensor(_builder.create_ceil(x.handle), x.type)
+ raise ValueError("ceil does not support boolean type")
+
+
+@core.builtin
+@_add_math_3arg_docstr("fused multiply-add")
+def fma(x, y, z, _builder=None):
+ x = semantic.to_tensor(x, _builder)
+ y = semantic.to_tensor(y, _builder)
+ z = semantic.to_tensor(z, _builder)
+ x, y = core.binary_op_type_legalization(x, y, _builder)
+ z, x = core.binary_op_type_legalization(z, x, _builder)
+ z, y = core.binary_op_type_legalization(z, y, _builder)
+ return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type)
+
+
+@core.builtin
+@_check_dtype(dtypes=["bf16", "fp16", "fp32"])
+@_add_math_1arg_docstr("hyperbolic tangent")
+@core._tensor_member_fn
+def tanh(x, _builder=None):
+ x = semantic.to_tensor(x, _builder)
+ return core.tensor(_builder.create_tanh(x.handle), x.type)
+
+
+@core.builtin
+@_add_math_2arg_docstr("ceiling division")
+@core._tensor_member_fn
+def cdiv(x, div, _builder=None):
+ if isinstance(x, core.constexpr):
+ x = x.value
+ if isinstance(div, core.constexpr):
+ div = div.value
+ from math import ceil as py_ceil
+ if isinstance(x, numbers.Number) and isinstance(div, numbers.Number):
+ if isinstance(x, bool) or isinstance(div, bool):
+ raise ValueError("cdiv does not support boolean type")
+ elif isinstance(x, int) and isinstance(div, int):
+ res = x // div
+ rem = x % div
+ return res + (1 if rem != 0 else 0)
+ else:
+ return py_ceil(x / div)
+
+ x = semantic.to_tensor(x, _builder)
+ div = semantic.to_tensor(div, _builder)
+ x_scalar_type = x.type.scalar
+ div_scalar_type = div.type.scalar
+ if x_scalar_type.is_bool() or div_scalar_type.is_bool():
+ raise ValueError("cdiv does not support 
boolean type") + elif x_scalar_type.is_int() and div_scalar_type.is_int(): + # integer cdiv: (x + div - 1) // div as before + return semantic.floordiv(semantic.add(x, semantic.sub(div, 1, True, _builder), True, _builder), div, _builder) + else: + div_res = semantic.truediv(x, div, _builder) + cdiv_res = core.tensor(_builder.create_ceil(div_res.handle), div_res.type) + return semantic.cast(cdiv_res, x_scalar_type, _builder) diff --git a/third_party/ascend/backend/spec/triton/language/semantic.py b/third_party/ascend/backend/spec/triton/language/semantic.py index 310daa3a8..7be3b70db 100644 --- a/third_party/ascend/backend/spec/triton/language/semantic.py +++ b/third_party/ascend/backend/spec/triton/language/semantic.py @@ -1,233 +1,1641 @@ -from typing import List, Optional, Union, Tuple +from __future__ import annotations # remove after python 3.11 +import warnings + +from typing import List, Optional, Sequence, Tuple, TypeVar import numbers -import triton.language as tl -from triton._C.libtriton import ir -import triton.language.core as core -import triton.language.standard as standard -from triton.language.semantic import (to_tensor, bitcast, wrap_tensor, cast, not_equal, permute, reshape, - _canonicalize_boundary_check) -from triton.language._utils import TRITON_MAX_TENSOR_NUMEL -from .tensor_descriptor import (_unwrap_if_constexpr, _unwrap_shape, block_type, tensor_descriptor) - -try: - import acl - is_compile_on_910_95 = acl.get_soc_name().startswith("Ascend910_95") -except Exception as e: - is_compile_on_910_95 = False - - -def ret_if_not_create_int_cast(src_sca_ty, dst_sca_ty, input, builder): - if not is_compile_on_910_95 and \ - (src_sca_ty.is_int_unsigned() or dst_sca_ty.is_int_unsigned()) and \ - src_sca_ty.int_bitwidth >= dst_sca_ty.int_bitwidth: - return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) - return None +from .._C.libtriton import ir +from . import core as tl +from . import math + +from . import is_compile_on_910_95 + +T = TypeVar('T') + + +class IncompatibleTypeErrorImpl(Exception): + + def __init__(self, type_a, type_b): + self.type_a = type_a + self.type_b = type_b + self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() + super(IncompatibleTypeErrorImpl, self).__init__(self.message) + + +# ===----------------------------------------------------------------------===## +# Programming Model +# ===----------------------------------------------------------------------===## + + +def program_id(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_program_id(axis), tl.int32) + + +def num_programs(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") + return tl.tensor(builder.create_get_num_programs(axis), tl.int32) + + +# ===----------------------------------------------------------------------===// +# Implicit Casting Utilities +# ===----------------------------------------------------------------------===// + + +def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: + a_rank = a_ty.int_bitwidth + b_rank = b_ty.int_bitwidth + a_sn = a_ty.int_signedness + b_sn = b_ty.int_signedness + # Rules for signedness taken from "Usual arithmetic conversions" on + # https://en.cppreference.com/w/c/language/conversion. 
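+    # e.g. int32 + int64 -> int64 (same signedness: higher rank wins);
+    #      int32 + uint32 -> uint32 (unsigned wins at equal rank);
+    #      uint32 + int64 -> int64 (the wider signed type can hold every uint32 value).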
+ if a_sn == b_sn: + return a_ty if a_rank > b_rank else b_ty + elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return a_ty if a_rank >= b_rank else b_ty + elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return b_ty if b_rank >= a_rank else a_ty + raise TypeError(f"unexpected signedness {a_sn} and {b_sn}") + + +def computation_type_impl(a_ty: tl.dtype, a_is_scalar: bool, b_ty: tl.dtype, b_is_scalar: bool, + div_or_mod: bool) -> tl.dtype: + # 0) For scalars we follow semantics similar to PyTorch, namely: + # - If the scalar is of a lower or equal kind (bool < uint < int < fp), + # it doesn't participate in the pomotion + if a_is_scalar != b_is_scalar: + scalar_ty, tensor_ty = (a_ty, b_ty) if a_is_scalar else (b_ty, a_ty) + if scalar_ty.kind().value <= tensor_ty.kind().value: + # Upcast because of 3) and 4) below! + if div_or_mod and (tensor_ty in (tl.float16, tl.bfloat16)): + return tl.float32 + return tensor_ty + + # 1) if one operand is double, the other is implicitly + # converted to double + if a_ty.is_fp64() or b_ty.is_fp64(): + return tl.float64 + # 2) if one operand is float, the other is implicitly + # converted to float + if a_ty.is_fp32() or b_ty.is_fp32(): + return tl.float32 + # 3 ) if one operand is half, the other is implicitly converted to half + # unless we're doing / or %, which do not exist natively in PTX for fp16. + # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp + if a_ty.is_fp16() or b_ty.is_fp16(): + if div_or_mod: + return tl.float32 + else: + return tl.float16 + # 4) return bf16 only if both operands are of bf16 + if a_ty.is_bf16() or b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + if a_ty.is_bf16() and b_ty.is_bf16(): + return tl.bfloat16 + return tl.float32 + # 5) return fp16 if operands are different fp8 + if a_ty.is_fp8() and b_ty.is_fp8(): + return a_ty if a_ty == b_ty else tl.float16 + if not a_ty.is_int() or not b_ty.is_int(): + raise TypeError(f"unexpected type {a_ty} and {b_ty}") + # 6 ) both operands are integer and undergo + # integer promotion + if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: + raise TypeError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + + " because they have different signedness;" + "this is unlikely to result in a useful answer. 
Cast them to the same signedness.") + return integer_promote_impl(a_ty, b_ty) + + +def to_tensor(x, builder, check_type: bool = True): + if isinstance(x, bool): + return tl.tensor(builder.get_int1(x), tl.int1) + # Note: compile-time const integers are represented by unsigned values + elif isinstance(x, int): + if -2**31 <= x < 2**31: + dtype = tl.int32 + elif 2**31 <= x < 2**32: + dtype = tl.uint32 + elif -2**63 <= x < 2**63: + dtype = tl.int64 + elif 2**63 <= x < 2**64: + dtype = tl.uint64 + else: + raise ValueError(f'Nonrepresentable integer {x}.') + return full((), x, dtype=dtype, builder=builder) + elif isinstance(x, float): + min_float32 = 2**-126 + max_float32 = (2 - 2**-23) * 2**127 + abs_x = __builtins__['abs'](x) + if abs_x == float("inf") or\ + abs_x == 0.0 or \ + x != x or \ + min_float32 <= abs_x <= max_float32: + dtype = tl.float32 + else: + dtype = tl.float64 + return full((), x, dtype=dtype, builder=builder) + + elif isinstance(x, tl.constexpr): + return to_tensor(x.value, builder) + elif isinstance(x, tl.tensor): + return x + if check_type: + raise TypeError(f"cannot convert {x} of type {type(x)} to tensor") + return x + + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + +def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorImpl(type_a, type_b) + + +def binary_op_type_checking_impl(lhs: tl.tensor | numbers.Number, rhs: tl.tensor | numbers.Number, builder: ir.builder, + allow_lhs_ptr=False, allow_rhs_ptr=False, arithmetic_check=True, + div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]: + lhs_is_scalar = isinstance(lhs, numbers.Number) + rhs_is_scalar = isinstance(rhs, numbers.Number) + if lhs_is_scalar: + lhs_scalar = lhs + lhs = to_tensor(lhs, builder) + if rhs_is_scalar: + rhs_scalar = rhs + rhs = to_tensor(rhs, builder) + + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = computation_type_impl(lhs_sca_ty, lhs_is_scalar, rhs_sca_ty, rhs_is_scalar, div_or_mod) + if (lhs_is_scalar and lhs_scalar < 0 and ret_sca_ty.is_int_unsigned() + or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()): + raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. 
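# --- Editor's note: illustrative sketch, not part of the patch ---
# How `to_tensor` above picks an integer dtype for a Python int literal,
# mirrored as a standalone helper so the ranges are easy to eyeball:
def _int_literal_dtype(x: int) -> str:
    if -2**31 <= x < 2**31:
        return "int32"
    if 2**31 <= x < 2**32:
        return "uint32"
    if -2**63 <= x < 2**63:
        return "int64"
    if 2**63 <= x < 2**64:
        return "uint64"
    raise ValueError(f"Nonrepresentable integer {x}.")

assert _int_literal_dtype(1) == "int32"
assert _int_literal_dtype(2**31) == "uint32"   # too big for int32, fits uint32
assert _int_literal_dtype(-2**40) == "int64"
assert _int_literal_dtype(2**63) == "uint64"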
" + "Perform a explicit cast on one of them.") + lhs = full( + (), lhs_scalar, dtype=ret_sca_ty, builder=builder) if lhs_is_scalar else cast(lhs, ret_sca_ty, builder) + rhs = full( + (), rhs_scalar, dtype=ret_sca_ty, builder=builder) if rhs_is_scalar else cast(rhs, ret_sca_ty, builder) + + # implicit broadcasting + lhs, rhs = broadcast_impl_value(lhs, rhs, builder) + return lhs, rhs + + +def binary_op_sanitize_overflow_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, binary_op: callable): + if lhs.type.scalar.int_bitwidth >= 64 or not builder.options.sanitize_overflow: + return + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + assert lhs_sca_ty == rhs_sca_ty + assert lhs_sca_ty.is_int() + lhs = cast(lhs, tl.int64, builder) + rhs = cast(rhs, tl.int64, builder) + ret = binary_op(lhs, rhs, False, builder) + max_value = lhs_sca_ty.get_int_max_value() + max_value = tl.tensor(builder.get_int64(max_value), tl.int64) + min_value = lhs_sca_ty.get_int_min_value() + min_value = tl.tensor(builder.get_int64(min_value), tl.int64) + cond = and_(less_equal(ret, max_value, builder), greater_equal(ret, min_value, builder), builder) + msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}" + device_assert(cond, msg, builder) + + +def add(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr(): + raise TypeError("cannot add pointers together") + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + if sanitize_overflow: + binary_op_sanitize_overflow_impl(input, other, builder, add) + return tl.tensor(builder.create_add(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def sub(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type) + # float - float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + if sanitize_overflow: + binary_op_sanitize_overflow_impl(input, other, builder, sub) + return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def mul(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, sanitize_overflow: bool, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fmul(input.handle, 
other.handle), input.type) + # int * int + elif scalar_ty.is_int(): + if sanitize_overflow: + binary_op_sanitize_overflow_impl(input, other, builder, mul) + return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") + + +def truediv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and other_scalar_ty.is_int(): + other = cast(other, input_scalar_ty, builder) + # int / float + elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): + input = cast(input, other_scalar_ty, builder) + # int / int (cast to tl.float32) + elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): + input = cast(input, tl.float32, builder) + other = cast(other, tl.float32, builder) + # float / float (cast to the highest exponent type) + elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): + if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: + other = cast(other, input_scalar_ty, builder) + else: + input = cast(input, other_scalar_ty, builder) + # unreachable + else: + raise TypeError(f"unexpected type {input_scalar_ty}") + return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) + + +def floordiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = cast(input, ret_ty, builder) + other = cast(other, ret_ty, builder) + if ret_ty.is_int_signed(): + return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {input_scalar_ty}") + + +def fdiv(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, ieee_rounding: bool, + builder: ir.builder) -> tl.tensor: + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): + raise TypeError("both operands of fdiv must have floating scalar type") + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) + ret = builder.create_fdiv(input.handle, other.handle) + return tl.tensor(ret, input.type) + + +def mod(input: tl.tensor | numbers.Number, other: tl.tensor | numbers.Number, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float % float + if scalar_ty.is_floating(): + # input - input.div(other, rounding_mode="floor") * other + floor = math.floor(fdiv(input, other, False, builder), _builder=builder) + ret = sub(input, mul(floor, other, True, builder), True, builder) + return ret + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise TypeError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because 
they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) + raise TypeError(f"unexpected type {scalar_ty}") -def check_arange_range_power_of_two(range, builder): - # Check if compile_mode is simt, then range must be a power of 2 - if builder.is_simt_mode(): - # Check if range is a power of 2 - if (range & (range - 1)) != 0: - raise ValueError("arange's range must be a power of 2") +############## +# other arithmetic ops +############## -def arange_disable_check_power_of_two(): - return True +def minimum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_minimumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_minnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_minsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_minui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") -def check_arange_less_than_max_numel(range): - if range > TRITON_MAX_TENSOR_NUMEL: - raise ValueError( - f"end - start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = {TRITON_MAX_TENSOR_NUMEL}") +def maximum(x: tl.tensor, y: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + x, y = binary_op_type_checking_impl(x, y, builder) + dtype = x.dtype + if dtype.is_floating(): + if propagate_nan == tl.PropagateNan.ALL: + return tl.tensor(builder.create_maximumf(x.handle, y.handle), x.type) + elif propagate_nan == tl.PropagateNan.NONE: + return tl.tensor(builder.create_maxnumf(x.handle, y.handle), x.type) + else: + raise ValueError(f"Unexpected propagate_nan {propagate_nan}") + elif dtype.is_int_signed(): + return tl.tensor(builder.create_maxsi(x.handle, y.handle), x.type) + elif dtype.is_int_unsigned(): + return tl.tensor(builder.create_maxui(x.handle, y.handle), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}") -def is_cast_src_dst_scalar_type_equal(src_sca_ty, dst_sca_ty): - if src_sca_ty == dst_sca_ty: - return True - return False +def clamp(x: tl.tensor, min: tl.tensor, max: tl.tensor, propagate_nan: tl.PropagateNan, builder: ir.builder): + min, max = binary_op_type_checking_impl(min, max, builder) + x, min = binary_op_type_checking_impl(x, min, builder) + x, max = binary_op_type_checking_impl(x, max, builder) -def check_unsupported_fp8_fp64(src_sca_ty, dst_sca_ty): - if not is_compile_on_910_95: - if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): - raise ValueError("[fp8, fp64] is unsupported on Ascend for now." - "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + dtype = x.dtype + if dtype.is_floating(): + return tl.tensor(builder.create_clampf(x.handle, min.handle, max.handle, propagate_nan), x.type) + else: + raise TypeError(f"Unexpected dtype {dtype}. 
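# --- Editor's note: illustrative sketch, not part of the patch ---
# The float path of `mod` above lowers to the floored-division identity
# x % y == x - floor(x / y) * y. Python's float `%` follows the same
# convention, so the identity can be sanity-checked on the host:
import math

def _floor_mod(x: float, y: float) -> float:
    return x - math.floor(x / y) * y

for x, y in ((7.5, 2.0), (-7.5, 2.0), (7.5, -2.0)):
    assert math.isclose(_floor_mod(x, y), x % y)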
Only floating point clamp is supported") -def ext_dot_operand_types(): - return (tl.int1, ) +############## +# bitwise ops +############## -def dot_check_hf32_input_precision(input_precision, ir, lhs, rhs, ret_scalar_ty): - if (input_precision == getattr(ir.INPUT_PRECISION, "HF32")): - if (not lhs.dtype.is_fp32() or not rhs.dtype.is_fp32() or not ret_scalar_ty.is_fp32()): - raise ValueError("input_precision = 'hf32' must be used with f32 * f32 = f32 on Ascend") +def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + input, other = binary_op_type_checking_impl(input, other, builder) + input_sca_ty = input.type.scalar + other_sca_ty = other.type.scalar + if not input_sca_ty.is_int() or not other_sca_ty.is_int(): + raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty) + ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty) + if ret_sca_ty != input_sca_ty: + input = cast(input, ret_sca_ty, builder) + if ret_sca_ty != other_sca_ty: + other = cast(other, ret_sca_ty, builder) + return input, other + +def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_and(input.handle, other.handle), input.type) -def dot_disable_check_max_num_imprecise_acc(): - print("max_num_imprecise_acc in tl.dot is not supported on Ascend yet. Thus it is ignored.") +def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_or(input.handle, other.handle), input.type) -def reset_dot_max_num_imprecise_acc(): - return 0 +def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) -def check_was_bool_to_int8_dtype(input): + +def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + dst_sca_ty = tl.dtype("int1") + dst_bits = dst_sca_ty.primitive_bitwidth if hasattr(input, 'was_bool_to_int8'): - if input.type.scalar.is_int8(): - raise TypeError(f"unexpected type bool") + assert input.type.scalar.is_int8(), "input wat bool to int8. However, input.type is not int8." + input = cast(input, tl.int1, builder) + if not input.type.is_int1(): + src_sca_ty = input.type.scalar + src_bits = src_sca_ty.primitive_bitwidth + if src_bits == dst_bits or src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + input = bitcast(input, tl.dtype("int1"), builder) + else: + input = not_equal(input, 0, builder) + if hasattr(other, 'was_bool_to_int8'): + assert other.type.scalar.is_int8(), "Other input wat bool to int8. However, other input.type is not int8." + other = cast(other, tl.int1, builder) + if not other.type.is_int1(): + src_sca_ty = other.type.scalar + src_bits = src_sca_ty.primitive_bitwidth + if src_bits == dst_bits or src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + other = bitcast(other, tl.dtype("int1"), builder) + else: + other = not_equal(other, 0, builder) + return and_(input, other, builder) -def cast_bool_to_specified_dtype(input, builder): +def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + dst_sca_ty = tl.dtype("int1") + dst_bits = dst_sca_ty.primitive_bitwidth if hasattr(input, 'was_bool_to_int8'): assert input.type.scalar.is_int8(), "input wat bool to int8. 
However, input.type is not int8." - return cast(input, tl.int1, builder) - if input.type.scalar.is_floating(): # NOTE: Only in not_? What about logical_and/logical_or? - raise TypeError(f"unexpected type {input.type.scalar}") - return input + input = cast(input, tl.int1, builder) + if not input.type.is_int1(): + src_sca_ty = input.type.scalar + src_bits = src_sca_ty.primitive_bitwidth + if src_bits == dst_bits or src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + input = bitcast(input, tl.dtype("int1"), builder) + else: + input = not_equal(input, 0, builder) + if hasattr(other, 'was_bool_to_int8'): + assert other.type.scalar.is_int8(), "Other wat bool to int8. However, other.type is not int8." + other = cast(other, tl.int1, builder) + if not other.type.is_int1(): + src_sca_ty = other.type.scalar + src_bits = src_sca_ty.primitive_bitwidth + if src_bits == dst_bits or src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + other = bitcast(other, tl.dtype("int1"), builder) + else: + other = not_equal(other, 0, builder) + return or_(input, other, builder) -def check_unexpected_dtype_float(input): +def not_(input: tl.tensor, builder: ir.builder): + if hasattr(input, 'was_bool_to_int8'): + assert input.type.scalar.is_int8(), "input wat bool to int8. However, input.type is not int8." + input = cast(input, tl.int1, builder) if input.type.scalar.is_floating(): raise TypeError(f"unexpected type {input.type.scalar}") + return invert(input, builder) -def check_unexpected_dtype_bool(dtype): - if dtype.is_bool(): - raise TypeError(f"Unexpected dtype {dtype}") +def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) -def set_load_legacy_other_input(other, mask, care_padding, builder): - if mask is not None and other is None and care_padding == True: - return to_tensor(0, builder) - return other +def ashr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type) -def disable_cast_back_when_load_legacy_ptr_is_bool(): - return True +def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) -def set_attr_was_bool_to_int8(ret, is_bool): - if is_bool: - ret.was_bool_to_int8 = True +# ===----------------------------------------------------------------------===// +# Unary Operators +# ===----------------------------------------------------------------------===// -def atomic_disable_original_check(): - return True +def plus(input: tl.tensor) -> tl.tensor: + return input -def atomic_cas_disable_element_bitwidth_check(): - return True +def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: + input_sca_ty = input.type.scalar + if hasattr(input, 'was_bool_to_int8'): + if input.type.scalar.is_int8(): + raise TypeError(f"unexpected type bool") + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return sub(_0, input, True, builder) -def ext_atomic_cas_element_typechecking(element_ty): - if element_ty in [tl.int1, tl.int8, tl.float64, tl.bfloat16]: - raise ValueError(f"atomic_cas 
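# --- Editor's note: illustrative sketch, not part of the patch ---
# `ashr` keeps the sign bit while `lshr` shifts in zeros. Emulating both for a
# 32-bit value on the host makes the difference concrete:
def _ashr32(x: int, n: int) -> int:
    return x >> n                      # Python's >> on ints is arithmetic

def _lshr32(x: int, n: int) -> int:
    return (x & 0xFFFFFFFF) >> n       # reinterpret as unsigned, then shift

assert _ashr32(-8, 1) == -4
assert _lshr32(-8, 1) == 0x7FFFFFFC    # zero-filled from the top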
does not support {str(element_ty)}. " - "All support dtypes are int16, int32, int64, float16, float32.") +def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: + if hasattr(input, 'was_bool_to_int8'): + assert input.type.scalar.is_int8(), "input wat bool to int8. However, input.type is not int8." + input = cast(input, tl.int1, builder) + input_sca_ty = input.type.scalar + if input_sca_ty.is_floating(): + raise TypeError(f"unexpected type {input_sca_ty}") + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return xor_(input, _1, builder) + + +# ===----------------------------------------------------------------------===// +# Comparison Operators +# ===----------------------------------------------------------------------===// +def _bool_like(v: tl.tensor) -> tl.block_type: + if not v.type.is_block(): + return tl.int1 + shape = v.type.shape + return tl.block_type(tl.int1, shape) + + +def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULE(input.handle, other.handle), 
_bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) + raise TypeError(f"unexpected type {scalar_ty}") + + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") + range = end - start + # Check if compile_mode is simt, then range must be a power of 2 + if builder.is_simt_mode(): + # Check if range is a power of 2 + if (range & (range - 1)) != 0: + raise ValueError("arange's range must be a power of 2") + shape = [range] + ret_ty = tl.block_type(tl.int32, shape) + return tl.tensor(builder.create_make_range(start, end), ret_ty) -def is_atomic_max_no_bitcast(): - return True +def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + if isinstance(value, tl.tensor): + assert value.numel.value == 1, "only accepts size-1 tensor" + value = cast(value, dtype, builder) + else: + # scalar + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") + if value == 0: + value = builder.get_null_value(dtype.to_ir(builder)) + else: + get_value_fn = getattr(builder, f"get_{dtype.name}") + value = get_value_fn(value) + value = tl.tensor(value, dtype) + return splat(value, shape, builder) -def is_atomic_min_no_bitcast(): - return True +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// -def atomic_max_returning_tensor(ir, ptr, val, mask, sem, scope, builder): - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), - val.type) +def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + assert not value.type.is_block(), "Cannot splat a block tensor" + if len(shape) == 0: + return value + ret_ty = tl.block_type(value.dtype, shape) + return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) + + +def reshape(input: tl.tensor, dst_shape: List[int], can_reorder: bool, builder: ir.builder) -> 
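# --- Editor's note: usage sketch, not part of the patch ---
# `arange` above requires a power-of-two range when the builder is in SIMT
# mode, and comparisons yield int1 masks (`_bool_like`). A typical kernel that
# relies on both (names are illustrative):
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):   # BLOCK: power of two
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n                       # int1 block, built via less_than above
    x = tl.load(src_ptr + offs, mask=mask)
    tl.store(dst_ptr + offs, x, mask=mask)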
tl.tensor: + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("reshape() cannot change total number of elements in tensor") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_reshape(input.handle, dst_shape, can_reorder), ret_ty) -def atomic_min_returning_tensor(ir, ptr, val, mask, sem, scope, builder): - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), - val.type) +def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + dst_shape = [tl._constexpr_to_value(x) for x in input.shape] + dst_shape.insert(axis, 1) -def is_float_format_support_bf16(): - return True + if not input.type.is_block(): + return splat(input, shape=dst_shape, builder=builder) + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty) -def is_float_format_support_fp16(): - return True +def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor: + assert can_reorder, "current implementation of `cat` always may reorder elements" + assert len(lhs.shape) == 1 + ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) + return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type) -def floating_mod_returning_tensor(builder, input, other): - return tl.tensor(builder.create_mod(input.handle, other.handle), input.type) +def join(a: tl.tensor, b: tl.tensor, builder: ir.builder) -> tl.tensor: + a, b = broadcast_impl_value(a, b, builder) -def logical_check_int1_bitcast(input, dst_sca_ty, dst_bits, builder): - src_sca_ty = input.type.scalar + # The IR can't handle joining two scalars, so upcast them to 1D tensors, + # then downcast the result. 
+ was_rank_1 = a.shape == [] + if was_rank_1: + a = expand_dims(a, 0, builder) + b = expand_dims(b, 0, builder) + + if isinstance(a.shape[-1], tl.constexpr): + two = tl.constexpr(2) + else: + two = 2 + new_shape = a.shape + [two] + + ret_type = tl.block_type(a.type.scalar, new_shape) + ret = tl.tensor(builder.create_join(a.handle, b.handle), ret_type) + + if was_rank_1: + ret = reshape(ret, [2], can_reorder=False, builder=builder) + + return ret + + +def split(a: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + assert (len(a.shape) > 0) + assert (tl._constexpr_to_value(a.shape[-1]) == 2) + + new_shape = a.shape[:-1] + ret_type = tl.block_type(a.type.scalar, new_shape) + outLHS, outRHS = builder.create_split(a.handle) + return ( + tl.tensor(outLHS, ret_type), + tl.tensor(outRHS, ret_type), + ) + + +def permute(input: tl.tensor, dims: Tuple[int], builder: ir.builder) -> tl.tensor: + if len(input.shape) != len(dims): + raise ValueError("permute dims must have the same length as input shape") + if sorted(tl._constexpr_to_value(d) for d in dims) != list(range(len(dims))): + raise ValueError(f"permute dims must be a permutation of 0, 1, ..., n-1, but were {dims}") + + ret_type = tl.block_type(input.type.scalar, [input.shape[d] for d in dims]) + return tl.tensor(builder.create_trans(input.handle, dims), ret_type) + + +def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + if not input.type.is_block(): + ret_ty = tl.block_type(input.type, shape) + return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + for i, item in enumerate(src_shape): + if shape[i] != item and item != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({item}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") + ret_ty = tl.block_type(input.type.scalar, shape) + return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) + + +def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) + rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) + lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + + if len(lhs_shape) < len(rhs_shape): + # Add new axes to lhs + for _ in range(len(lhs_shape), len(rhs_shape)): + lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), + tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) + lhs_ty = lhs.type + lhs_shape = lhs_ty.get_block_shapes() + elif len(rhs_shape) < len(lhs_shape): + # Add new axes to rhs + for _ in range(len(rhs_shape), len(lhs_shape)): + rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), + tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) + rhs_ty = rhs.type + rhs_shape = rhs_ty.get_block_shapes() + assert len(rhs_shape) == 
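# --- Editor's note: usage sketch, not part of the patch ---
# `expand_dims` and `broadcast_impl_value` above are what back the usual
# `[:, None]` / `[None, :]` idiom for building 2-D offset tiles (illustrative):
import triton
import triton.language as tl

@triton.jit
def tile_offsets(M: tl.constexpr, N: tl.constexpr):
    rows = tl.arange(0, M)[:, None]       # expand_dims(axis=1): shape (M, 1)
    cols = tl.arange(0, N)[None, :]       # expand_dims(axis=0): shape (1, N)
    return rows * N + cols                # broadcast to (M, N)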
len(lhs_shape) + + ret_shape = [] + for i, left in enumerate(lhs_shape): + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif (right == 1) or (right == left): + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + + +####### +# cast +####### + + +def _str_to_rounding_mode(rounding_mode: Optional[str]): + if rounding_mode is None: + return None + if rounding_mode == 'rtne': + return ir.ROUNDING_MODE.RTNE + if rounding_mode == 'rtz': + return ir.ROUNDING_MODE.RTZ + raise ValueError(f"Invalid rounding mode: {rounding_mode}. Supported rounding modes are 'rtne' and 'rtz'.") + + +def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return cast(input, dst_ty, builder) + # Bitcast src_bits = src_sca_ty.primitive_bitwidth - if src_bits == dst_bits or src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): - input = bitcast(input, tl.dtype("int1"), builder) + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " + "data-type of size " + str(dst_bits)) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + +def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder, fp_downcast_rounding: Optional[str] = None, + overflow_mode: Optional[str] = None) -> tl.tensor: + src_ty = input.type + if isinstance(dst_ty, tl.constexpr): + dst_ty = dst_ty.value + if isinstance(fp_downcast_rounding, tl.constexpr): + fp_downcast_rounding = fp_downcast_rounding.value + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True else: - input = not_equal(input, 0, builder) + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." 
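# --- Editor's note: illustrative sketch, not part of the patch ---
# The per-dimension rule in `broadcast_impl_value` above, extracted as a
# standalone helper: a size-1 dimension stretches to the other side, equal
# sizes pass through, anything else is an error.
def _broadcast_shapes(lhs, rhs):
    if len(lhs) != len(rhs):
        raise ValueError("ranks must already match at this point")
    out = []
    for i, (left, right) in enumerate(zip(lhs, rhs)):
        if left == 1:
            out.append(right)
        elif right == 1 or right == left:
            out.append(left)
        else:
            raise ValueError(f"incompatible dimensions at index {i}: {left} and {right}")
    return out

assert _broadcast_shapes([1, 8], [4, 8]) == [4, 8]
assert _broadcast_shapes([4, 1], [1, 8]) == [4, 8]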
+ return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) -def ext_dot_scaled_validate_lhs_dtype(lhs): - assert lhs.dtype == tl.bfloat16 or lhs.dtype == tl.float16, f"lhs matrix dtype must be bf16 or fp16" + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif overflow_mode == "saturate" and \ + (src_sca_ty.is_int_unsigned() or dst_sca_ty.is_int_unsigned()) and \ + src_sca_ty.int_bitwidth >= dst_sca_ty.int_bitwidth: + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + else: + return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) -def ext_dot_scaled_validate_rhs_dtype(rhs): - assert rhs.dtype == tl.bfloat16 or rhs.dtype == tl.float16, f"rhs matrix dtype must be bf16 or fp16" + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return 
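# --- Editor's note: usage sketch, not part of the patch ---
# `cast` above only accepts `fp_downcast_rounding` for truncating float
# conversions ('rtne' by default, 'rtz' opt-in). From kernel code this maps to
# the usual `.to(...)` call (illustrative):
import triton
import triton.language as tl

@triton.jit
def downcast(x_ptr, y_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)                              # e.g. float32 input
    y = x.to(tl.float16, fp_downcast_rounding="rtz")       # truncating cast
    tl.store(y_ptr + offs, y)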
tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) + if bitwidth == 1: + return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder) + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) -def ext_dot_scaled_check_same_dtype(lhs, rhs): - assert lhs.dtype == rhs.dtype, f"lhs rhs matrix must get same dtype" + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + assert False, f'cannot cast {input} to {dst_ty}' -def dot_scaled_disable_original_check(): - return True +# ===----------------------------------------------------------------------===// +# Memory Operators +# ===----------------------------------------------------------------------===// -def ext_dot_scaled_check_lhs_rhs_format(lhs_format, rhs_format): - lhs_format: str = lhs_format.value - rhs_format: str = rhs_format.value - allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" - assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" - assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" +def _str_to_load_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cv": + cache = ir.CACHE_MODIFIER.CV + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_store_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".wb": + cache = ir.CACHE_MODIFIER.WB + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + elif cache_modifier == ".cs": + cache = ir.CACHE_MODIFIER.CS + elif cache_modifier == ".wt": + cache = ir.CACHE_MODIFIER.WT + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache -def dot_scaled_recheck_rhs_scale_is_none(rhs_scale, rhs_scale_is_none): - rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) - return rhs_scale_is_none +def _str_to_eviction_policy(eviction_policy): + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + return eviction -def dot_scaled_check_lhs_scale_is_none(lhs_scale): - lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None) - return lhs_scale_is_none +def _str_to_padding_option(padding_option): + padding = None # default + if padding_option: + if padding_option == "zero": + padding = ir.PADDING_OPTION.PAD_ZERO + elif padding_option == "nan": + padding = ir.PADDING_OPTION.PAD_NAN + else: + raise ValueError(f"Padding option {padding_option} not supported") + return padding + + +def _str_to_sem(sem_option): + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + if sem_option: + if sem_option == "acquire": + sem = ir.MEM_SEMANTIC.ACQUIRE + elif sem_option == "release": + sem = ir.MEM_SEMANTIC.RELEASE + elif sem_option == 
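# --- Editor's note: usage sketch, not part of the patch ---
# The string tables above (_str_to_load_cache_modifier, _str_to_store_cache_modifier,
# _str_to_eviction_policy) parse the optional hints passed to tl.load / tl.store:
import triton
import triton.language as tl

@triton.jit
def streaming_read(src_ptr, dst_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(src_ptr + offs, cache_modifier=".cg", eviction_policy="evict_first")
    tl.store(dst_ptr + offs, x, cache_modifier=".wt")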
"acq_rel": + sem = ir.MEM_SEMANTIC.ACQUIRE_RELEASE + elif sem_option == "relaxed": + sem = ir.MEM_SEMANTIC.RELAXED + else: + raise ValueError(f"Memory semantic {sem_option} not supported") + return sem + + +def _str_to_scope(scope_option): + scope = ir.MEM_SYNC_SCOPE.GPU + if scope_option: + if scope_option == "gpu": + scope = ir.MEM_SYNC_SCOPE.GPU + elif scope_option == "cta": + scope = ir.MEM_SYNC_SCOPE.CTA + elif scope_option == "sys": + scope = ir.MEM_SYNC_SCOPE.SYSTEM + else: + raise ValueError(f"Memory semantic {scope_option} not supported") + return scope + + +def _canonicalize_boundary_check(boundary_check, block_shape): + if boundary_check: + if not hasattr(boundary_check, "__iter__"): + boundary_check = [boundary_check] + boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check] + for dim in boundary_check: + assert isinstance(dim, int) and 0 <= dim < len(block_shape) + assert len(boundary_check) > 0 + assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`" + return sorted(boundary_check) + return () + + +def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a block pointer: `pointer_type>` + # Block pointer can not have `mask` and `other` arguments + if mask is not None or other is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") -def is_dot_scaled_support_rhs_scale(): - return True + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer block pointers") + # `dst_ty` is de-referenced type of the pointer type + dst_ty = ptr.type.element_ty -def check_dot_scaled_lhs_scale_dtype(lhs_scale): - assert isinstance(lhs_scale, tl.tensor) and lhs_scale.dtype == tl.int8, f"lhs_scale must be int8 tensor" + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) + # Build IR + return tl.tensor( + builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) -def check_dot_scaled_rhs_scale_dtype(rhs_scale, rhs_scale_is_none): - if not rhs_scale_is_none: - assert isinstance(rhs_scale, tl.tensor) and rhs_scale.dtype == tl.int8, f"rhs_scale must be int8 tensor" +def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, care_padding, builder): + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") -def dot_scaled_lrhs_k_pack(lhs_k_pack, rhs_k_pack, builder): - if lhs_k_pack == False: - dims = (1, 0) - dims = core._unwrap_iterable(dims) - tmp_lhs = permute(lhs, dims, builder) - lhs = reshape(tmp_lhs, (lhs.shape[0], lhs.shape[1]), True, builder) + # Check `mask`, `other`, `boundary_check`, and `padding` arguments + if mask is None and other is not None: + raise ValueError("`other` cannot be provided without `mask`") + if padding or boundary_check: + raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of" + "pointers or loading a scalar. 
Because the compiler does not know the boundary; please " + "use block pointers (defined by `make_block_ptr`) instead") - if rhs_k_pack == False: - dims = (1, 0) - dims = core._unwrap_iterable(dims) - tmp_rhs = permute(rhs, dims, builder) - rhs = reshape(tmp_rhs, (rhs.shape[0], rhs.shape[1]), True, builder) + if mask is not None and other is None and care_padding == True: + # Get element type to determine default padding value + elt_ty = ptr.type.scalar.element_ty + # Use 0.0 for floating point types, 0 for integer types + default_value = 0.0 if elt_ty.is_floating() else 0 + other = to_tensor(default_value, builder) + # For a pointer of scalar, check the type of `mask` and `other` + if not ptr.type.is_block(): + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + if other and other.type.is_block(): + raise ValueError("Other argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `other` into the same shape as `ptr` + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if other is not None: + other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) + + # Get `pointer_type` and `elt_ty` + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + is_bool = elt_ty == tl.int1 + if is_bool: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast `other` into `elt_ty` type + if other is not None: + other = cast(other, elt_ty, builder) + + # Create loaded result type `dst_ty` + if ptr.type.is_block(): + shape = ptr.type.get_block_shapes() + dst_ty = tl.block_type(elt_ty, shape) + else: + # Load by de-referencing the pointer of scalar + dst_ty = elt_ty + + # Build IR + if mask is None: + ret = tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + else: + ret = tl.tensor( + builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, + is_volatile), dst_ty) + # Do not cast back to int1 when is_bool=true. 
We directly use the int8 tensor given by tl.load + if is_bool: + ret.was_bool_to_int8 = True + + return ret + + +def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check: Tuple, + padding_option: str, cache_modifier: str, eviction_policy: str, is_volatile: bool, care_padding: bool, + builder: ir.builder) -> tl.tensor: + # Cache, eviction and padding options + cache = _str_to_load_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + padding = _str_to_padding_option(padding_option) + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Load by a block pointer: `pointer_type>` + return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + else: + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, care_padding, + builder) + + +def tensormap_create( + desc_ptr: tl.tensor, + global_address: tl.tensor, + box_dim: List[tl.tensor], + global_dim: List[tl.tensor], + global_stride: List[tl.tensor], + element_stride: List[tl.tensor], + elem_type: int, + interleave_layout: int, + swizzle_mode: int, + fill_mode: int, + builder: ir.builder, +) -> tl.tensor: + assert not global_stride or global_stride[0].dtype == tl.int64 + return tl.tensor( + builder.create_tensormap_create( + desc_ptr.handle, + global_address.handle, + [x.handle for x in box_dim], + [x.handle for x in global_dim], + [x.handle for x in global_stride], + [x.handle for x in element_stride], + elem_type, + interleave_layout, + swizzle_mode, + fill_mode, + ), + tl.void, + ) + + +def tensormap_fenceproxy_acquire(desc_ptr: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_tensormap_fenceproxy_acquire(desc_ptr.handle), tl.void) + + +def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a block pointer: `pointer_type>` + # Block pointers can not have the `mask` argument + if mask is not None: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + # Check same shape and element type + block_shape = ptr.type.element_ty.get_block_shapes() + if not val.type.is_block(): + val = broadcast_impl_shape(val, block_shape, builder) + assert val.type.is_block(), "Value argument must be block type or a scalar" + assert block_shape == val.type.get_block_shapes( + ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" + assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, block_shape) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, cache, eviction), + tl.void) + + +def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`") + + # Check `boundary_check` argument + 
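# --- Editor's note: usage sketch, not part of the patch ---
# `_load_legacy` above broadcasts mask/other to the pointer block, substitutes
# a zero of the element type when `other` is omitted (with care_padding), and
# loads bool pointers as int8 tagged `was_bool_to_int8` instead of casting
# back. Typical call sites look like:
import triton
import triton.language as tl

@triton.jit
def masked_read(src_ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    mask = offs < n
    a = tl.load(src_ptr + offs, mask=mask)              # padded with 0 of elt type
    b = tl.load(src_ptr + offs, mask=mask, other=-1.0)  # explicit padding value
    return a + b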
if boundary_check: + raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a " + "scalar. Because the compiler does not know the boundary; please use block pointers " + "(defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `val` and `mask` + if not ptr.type.is_block(): + if val.type.is_block(): + raise ValueError("Value argument cannot be block type if pointer argument is not a block") + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `val` into the same shape as `ptr` + if ptr.type.is_block(): + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + if not mask: + return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void) + + +def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str, + eviction_policy: str, builder: ir.builder) -> tl.tensor: + # Cache and eviction options + cache = _str_to_store_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + + if ptr.type.is_const() or ptr.type.scalar.is_const(): + raise ValueError("Cannot store to a constant pointer") + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Store by a block pointer: `pointer_type>` + return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder) + else: + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder) + + +######### +# atomic +######### + + +def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + element_ty = ptr.type.scalar.element_ty + if not is_compile_on_910_95: + supported_types = [tl.int8, tl.uint8, tl.int16, tl.int32, tl.int64, tl.float16, tl.bfloat16, tl.float32] + if element_ty not in supported_types: + raise ValueError(f"atomic_cas does not support {str(element_ty)}. " + "All support dtypes are int8, uint8, int16, int32, int64, float16, bfloat16, float32.") + else: + unsupported_types = [tl.int1] + if element_ty in unsupported_types: + raise ValueError( + f"atomic_cas does not support {str(element_ty)}. " + "All support dtypes are int8, uint8, int16, uint16, int32, uint32, int64, uint64, fp8e4m3, fp8e5m2, float16, bfloat16, float32." 
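# --- Editor's note: usage sketch, not part of the patch ---
# `atomic_cas` above restricts the element type per SoC (int1 is always
# rejected; non-910_95 parts allow int8/uint8/int16/int32/int64/fp16/bf16/fp32).
# A simple lock-style usage (illustrative):
import triton
import triton.language as tl

@triton.jit
def try_acquire(lock_ptr):
    # returns the old value: 0 means the lock was taken by this program
    return tl.atomic_cas(lock_ptr, 0, 1, sem="acquire", scope="gpu")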
+ ) + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) + + +def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + if ptr.type.is_const() or ptr.type.element_ty.is_const(): + raise ValueError("Cannot store to a constant pointer") + if ptr.type.is_block(): + if mask is not None: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if val is not None: + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + val = cast(val, ptr.type.scalar.element_ty, builder) + if not mask: + mask_ir = builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + mask = tl.tensor(mask_ir, mask_ty) + return ptr, val, mask + + +def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + # Design for NPU + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + else: + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + # Design for NPU + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +# def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: +# ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) +# sem = _str_to_sem(sem) +# scope = _str_to_scope(scope) +# sca_ty = val.type.scalar +# # direct call to atomic_max for integers +# if sca_ty.is_int(): +# if sca_ty.is_int_signed(): +# return tl.tensor( +# builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) +# else: +# return tl.tensor( +# builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) +# # for float +# # return atomic_smax(i_ptr, i_val) if val >= 0 +# # return atomic_umin(i_ptr, i_val) if val < 0 +# if sca_ty not in {tl.float32, tl.float64}: +# raise TypeError(f"atomic_max not supported for dtype {sca_ty}") + +# zero = full([], 0.0, 
sca_ty, builder) + +# i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 +# i_val = bitcast(val, i_type, builder) +# i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) +# ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 +# ui_val = bitcast(val, ui_type, builder) +# ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) +# pos = greater_equal(val, zero, builder) +# neg = less_than(val, zero, builder) +# pos_ret = tl.tensor( +# builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, +# and_(mask, pos, builder).handle, sem, scope), i_val.type) +# neg_ret = tl.tensor( +# builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ui_ptr.handle, ui_val.handle, +# and_(mask, neg, builder).handle, sem, scope), ui_val.type) +# ret = where(pos, pos_ret, neg_ret, builder) +# return bitcast(ret, sca_ty, builder) + +# def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: +# ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) +# sem = _str_to_sem(sem) +# scope = _str_to_scope(scope) +# sca_ty = val.type.scalar +# # direct call to atomic_min for integers +# if sca_ty.is_int(): +# if sca_ty.is_int_signed(): +# return tl.tensor( +# builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) +# else: +# return tl.tensor( +# builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) +# # for float +# # return atomic_smin(i_ptr, i_val) if val >= 0 +# # return atomic_umax(i_ptr, i_val) if val < 0 +# if sca_ty not in {tl.float32, tl.float64}: +# raise TypeError(f"atomic_min not supported for dtype {sca_ty}") + +# zero = full([], 0.0, sca_ty, builder) + +# i_type = tl.int32 if sca_ty == tl.float32 else tl.int64 +# i_val = bitcast(val, i_type, builder) +# i_ptr = bitcast(ptr, tl.pointer_type(i_type, 1), builder) +# ui_type = tl.uint32 if sca_ty == tl.float32 else tl.uint64 +# ui_val = bitcast(val, ui_type, builder) +# ui_ptr = bitcast(ptr, tl.pointer_type(ui_type, 1), builder) +# pos = greater_equal(val, zero, builder) +# neg = less_than(val, zero, builder) +# pos_ret = tl.tensor( +# builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, +# and_(mask, pos, builder).handle, sem, scope), i_val.type) +# neg_ret = tl.tensor( +# builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ui_ptr.handle, ui_val.handle, +# and_(mask, neg, builder).handle, sem, scope), ui_ptr.type) +# ret = where(pos, pos_ret, neg_ret, builder) +# return bitcast(ret, sca_ty, builder) + + +def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + + +def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> 
tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) + sem = _str_to_sem(sem) + scope = _str_to_scope(scope) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// -def _bitcast_to_fp_type(val, float_format, builder): +def _str_to_dot_input_precision(input_precision, builder): + assert input_precision.lower() in builder.options.allowed_dot_input_precisions, \ + f"input_precision must be one of {builder.options.allowed_dot_input_precisions}. Got {input_precision}" + input_precision = input_precision.upper() + if input_precision == "TF32X3": + input_precision = "TF32x3" + return getattr(ir.INPUT_PRECISION, input_precision) + + +def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optional[str], max_num_imprecise_acc: int, + out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + + if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): + # All combinations of supported fp8 x fp8 are permitted + pass + else: + assert lhs.dtype in (tl.int1, tl.int8, tl.uint8, tl.float16, tl.bfloat16, + tl.float32), f"Unsupported lhs dtype {lhs.dtype}" + assert rhs.dtype in (tl.int1, tl.int8, tl.uint8, tl.float16, tl.bfloat16, + tl.float32), f"Unsupported rhs dtype {rhs.dtype}" + assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}" + + if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15(): + lhs = cast(lhs, tl.float16, builder) + rhs = cast(rhs, tl.float16, builder) + + if input_precision is None: + input_precision = builder.options.default_dot_input_precision + + input_precision = _str_to_dot_input_precision(input_precision, builder) + + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + assert lhs.shape[-1].value == rhs.shape[ + -2].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[-1].value}) must be equal to first index of second shape ({rhs.shape[-2].value})" + assert builder.codegen_fns.get("min_dot_size") is not None, "target doesn't provide lower shape bounds for dot." 
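# (Editor's note: illustrative sketch, not part of the original patch.)
# Example of the lower-bound check performed next: if the target reported
# min_dot_size == (16, 16, 16) (a hypothetical value), then for
#     acc = tl.dot(a, b)    # a: (BLOCK_M, BLOCK_K), b: (BLOCK_K, BLOCK_N)
# the kernel must choose BLOCK_M >= 16, BLOCK_N >= 16 and BLOCK_K >= 16,
# otherwise the assert below fails at compile time.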
+ min_dot_size = builder.codegen_fns["min_dot_size"](lhs.type, rhs.type) + assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \ + and rhs.shape[-1].value >= min_dot_size[1], \ + f"Input shapes should have M >= {min_dot_size[0]}, N >= {min_dot_size[1]} and K >= {min_dot_size[2]}" + if lhs.type.scalar.is_int(): + assert lhs.type.scalar == tl.int8, "only int8 supported!" + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError( + "out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") + elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): + _0 = builder.get_fp32(0) + ret_scalar_ty = tl.float32 + else: + _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0) + ret_scalar_ty = out_dtype + + M = lhs.type.shape[-2] + N = rhs.type.shape[-1] + K = lhs.type.shape[-1] + B = lhs.type.shape[0] if lhs_rank == 3 else None + ret_ty = tl.block_type(ret_scalar_ty, [B, M, N] if B else [M, N]) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + + if (input_precision == getattr(ir.INPUT_PRECISION, "HF32")): + if (not lhs.dtype.is_fp32() or not rhs.dtype.is_fp32() or not ret_scalar_ty.is_fp32()): + raise ValueError("input_precision = 'hf32' must be used with f32 * f32 = f32 on Ascend") + + if max_num_imprecise_acc is not None: + print("max_num_imprecise_acc in tl.dot is not supported on Ascend yet. Thus it is ignored.") + max_num_imprecise_acc = 0 + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, input_precision, max_num_imprecise_acc), + ret_ty) + + +def _str_to_fp_type(float_format: Optional[str]): + if float_format == 'e4m3': + return ir.F8F6F4TY.E4M3 + if float_format == 'e5m2': + return ir.F8F6F4TY.E5M2 + if float_format == 'e2m3': + return ir.F8F6F4TY.E2M3 + if float_format == 'e3m2': + return ir.F8F6F4TY.E3M2 + if float_format == 'e2m1': + return ir.F8F6F4TY.E2M1 + if float_format == 'bf16': + return ir.F8F6F4TY.BF16 + if float_format == 'fp16': + return ir.F8F6F4TY.FP16 + raise ValueError(f"Invalid float format: {float_format}.") + + +def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder): triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16, "fp16": tl.float16}.get(float_format) if triton_ty is None: assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}" @@ -241,35 +1649,152 @@ def _bitcast_to_fp_type(val, float_format, builder): return bitcast(val, triton_ty, builder) -def dot_scaled_lhs_bitcast_to_fp_type(lhs, lhs_format, builder): +def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], + rhs_format: str, acc: Union[tl.tensor, None], out_dtype: tl.dtype, lhs_k_pack, rhs_k_pack, + builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + assert lhs.dtype == tl.bfloat16 or lhs.dtype == tl.float16, f"lhs matrix dtype must be bf16 or fp16" + assert rhs.dtype == tl.bfloat16 or rhs.dtype == tl.float16, f"rhs matrix dtype must be bf16 or fp16" + assert lhs.dtype == rhs.dtype, f"lhs rhs matrix must get same dtype" + lhs_rank = len(lhs.shape) + rhs_rank = len(rhs.shape) + assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})" lhs_format: str = lhs_format.value - lhs = 
_bitcast_to_fp_type(lhs, lhs_format, builder) - return lhs - - -def dot_scaled_rhs_bitcast_to_fp_type(rhs, rhs_format, builder): rhs_format: str = rhs_format.value + lhs_format_enum = _str_to_fp_type(lhs_format) + rhs_format_enum = _str_to_fp_type(rhs_format) + allowed_formats = {"bf16", "fp16"} # unsupported fp8/4 dtype: "e2m1", "e4m3", "e5m2" + assert lhs_format in allowed_formats, f"NYI: lhs_format {lhs_format}" + assert rhs_format in allowed_formats, f"NYI: rhs_format {rhs_format}" + rhs_scale_is_none = rhs_scale is None or (isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None) + lhs_scale_is_none = lhs_scale is None or (isinstance(lhs_scale, tl.constexpr) and lhs_scale.value is None) + assert isinstance(lhs_scale, tl.tensor) and lhs_scale.dtype == tl.int8, f"lhs_scale must be int8 tensor" + if not rhs_scale_is_none: + assert isinstance(rhs_scale, tl.tensor) and rhs_scale.dtype == tl.int8, f"rhs_scale must be int8 tensor" + lhs = _bitcast_to_fp_type(lhs, lhs_format, builder) rhs = _bitcast_to_fp_type(rhs, rhs_format, builder) - return rhs + if lhs_k_pack == False: + dims = (1, 0) + dims = core._unwrap_iterable(dims) + tmp_lhs = permute(lhs, dims, builder) + lhs = reshape(tmp_lhs, (lhs.shape[0], lhs.shape[1]), True, builder) + + if rhs_k_pack == False: + dims = (1, 0) + dims = core._unwrap_iterable(dims) + tmp_rhs = permute(rhs, dims, builder) + rhs = reshape(tmp_rhs, (rhs.shape[0], rhs.shape[1]), True, builder) -def check_dot_scaled_dimension(lhs, rhs): assert lhs.type.shape[-1] == rhs.type.shape[-2], (f"lhs last dimension (columns) {lhs.shape[-1]} " f"must equal rhs penultimate dimension (rows) {rhs.shape[-2]}") - - -def check_dot_scaled_pack_size(PACKED_A, K, lhs_format, lhs, rhs): - lhs_format: str = lhs_format.value + M = lhs.type.shape[-2] + K, N = rhs.type.shape[-2:] + PACKED_A = 2 if lhs_format == "e2m1" else 1 PACKED_B = 2 if lhs_format == "e2m1" else 1 assert K * PACKED_B == PACKED_A * lhs.type.shape[ -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})" + B = lhs.type.shape[0] if lhs_rank == 3 else None + + ret_ty = tl.block_type(out_dtype, [B, M, N] if B else [M, N]) + _0 = builder.get_fp32(0) + if acc is None: + acc_handle = builder.create_splat(_0, [B, M, N] if B else [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle + lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle + return tl.tensor( + builder.create_dot_scaled(lhs.handle, lhs_scale.handle, lhs_format_enum, rhs.handle, rhs_scale_handle, + rhs_format_enum, acc_handle), ret_ty) + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// -def set_dot_scaled_lhs_scale_handle(lhs_scale, lhs_scale_is_none): - return None if lhs_scale_is_none else lhs_scale.handle + +def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: + if condition.dtype != tl.int1: + warnings.warn( + f"tl.where with a non-boolean condition is deprecated and will error out in a future triton release. 
Got {condition.dtype}" + ) + condition = cast(condition, tl.int1, builder) + x, y = binary_op_type_checking_impl(x, y, builder, True, True) + # x, y are broadcasted + if condition.type.is_block(): + condition, x = broadcast_impl_value(condition, x, builder) + x, y = broadcast_impl_value(x, y, builder) + else: + condition, _ = broadcast_impl_value(condition, x, builder) + ret_ty = x.type + return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) -def ext_semantic_gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: +# ===----------------------------------------------------------------------===// +# Reduction +# ===----------------------------------------------------------------------=== + + +def wrap_tensor(x, scalar_ty, ret_shape): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty + return tl.tensor(x, res_ty) + + +def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]: + if axis is None: + inputs = tuple(reshape(t, [t.numel.value], can_reorder=True, builder=builder) for t in inputs) + axis = 0 + # get result shape + shape = inputs[0].type.shape + rank = len(shape) + assert axis < rank, f"reduction axis must be < inputs rank ({rank})" + ret_shape = [s for i, s in enumerate(shape) if i != axis] + assert all(t.type.shape == shape for t in inputs), "all reduction inputs must have the same shape" + + reduce_op = builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + reduce_op.verify() + + return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Associative Scan +# ===----------------------------------------------------------------------=== + + +def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, reverse: bool, + builder: ir.builder) -> Tuple[tl.tensor, ...]: + shape = inputs[0].type.shape + rank = len(shape) + + assert -rank <= axis < rank, f"scan axis {axis} must be < inputs rank ({rank})" + + if axis < 0: + axis += rank + + for t in inputs: + assert t.type.shape == shape, "all scan inputs must have the same shape" + + scan_op = builder.create_scan([t.handle for t in inputs], axis, reverse) + region_builder_fn(scan_op) + scan_op.verify() + + return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar, shape) for i in range(len(inputs))) + + +# ===----------------------------------------------------------------------=== +# Gather +# ===----------------------------------------------------------------------=== + + +def gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: assert index.dtype.is_int(), "index must be an integer tensor" if not (src.dtype.is_floating() or src.dtype.is_int8()): raise ValueError(f"Expected dtype fp16/fp32/bf16/f8E5M2/f8E4M3FN/int8, but got {src.dtype}") @@ -290,134 +1815,148 @@ def ext_semantic_gather(src: tl.tensor, index: tl.tensor, axis: int, builder: ir return wrap_tensor(gather, src.type.scalar, index.type.shape) -def ext_semantic_insert_slice(ful: tl.tensor, sub: tl.tensor, offsets: List[tl.tensor], sizes: List[int], - strides: List[int], builder: ir.builder) -> tl.tensor: - assert (len(ful.shape) == len(offsets)) - assert (len(ful.shape) == len(sizes)) - assert (len(ful.shape) == len(strides)) - assert (all([s >= 1 for s in 
sizes])) - assert (all([s >= 0 for s in strides])) - new_offsets = [o.handle for o in offsets] - ret_type = tl.block_type(ful.type.scalar, ful.shape) - out = builder.create_insert_slice(ful.handle, sub.handle, new_offsets, sizes, strides) - return tl.tensor(out, ret_type) +# ===----------------------------------------------------------------------=== +# Histogram +# ===----------------------------------------------------------------------=== -def ext_semantic_extract_slice(ful: tl.tensor, offsets: List[tl.tensor], sizes: List[int], strides: List[int], - builder: ir.builder) -> tl.tensor: - assert (len(ful.shape) == len(offsets)) - assert (len(ful.shape) == len(sizes)) - assert (len(ful.shape) == len(strides)) - assert (all([s >= 1 for s in sizes])) - assert (all([s >= 0 for s in strides])) - new_offsets = [o.handle for o in offsets] - ret_type = tl.block_type(ful.type.scalar, sizes) - out = builder.create_extract_slice(ful.handle, new_offsets, sizes, strides) - return tl.tensor(out, ret_type) +def histogram(input: tl.tensor, num_bins: int, builder: ir.builder) -> tl.tensor: + assert len(input.shape) == 1, "histogram only supports 1D input" + assert input.dtype.is_int(), "histogram only supports integer input" + return tl.tensor(builder.create_histogram(input.handle, num_bins), tl.block_type(tl.int32, (num_bins, ))) -def ext_semantic_get_element(src: tl.tensor, indice: List[tl.tensor], builder: ir.builder): - if len(src.shape) != len(indice): - raise ValueError("Indice's rank must be equal to src tensor's rank") +## - new_indice = [i.handle for i in indice] - result = builder.create_extract_scalar(src.handle, new_indice) - return wrap_tensor(result, src.type.scalar, None) +def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: + if max(1, len(x.shape)) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) + return x -def ext_semantic_compile_hint(ptr: tl.tensor, hint_name: str, hint_val, builder: ir.builder): - # simt mode does not support hint annotations - if builder.is_simt_mode(): + +def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_constancy(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_constancy does not match the length of values") + x.handle.set_attr("tt.constancy", ir.make_attr(values, x.handle.get_context())) + return x + + +def debug_barrier(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_barrier(), tl.void) + + +def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder) -> tl.tensor: + # It makes sense visually for prefix to end in ": "; make it so. Also, + # non-empty prefixes should start with " ". 
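# (Editor's note: worked example, not part of the original patch.)
# With at least one argument, a prefix of "x" is normalized step by step to
# "x " -> "x: " -> " x: ", so the printed line starts with " x: ".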
+ if not prefix.endswith(" ") and args: + prefix += " " + if not prefix.endswith(": ") and args: + prefix = prefix[:-1] + ": " + if len(prefix) > 2 and not prefix.startswith(" "): + prefix = " " + prefix + + new_args = [arg.handle for arg in args] + is_signed = [arg.dtype in (tl.int1, tl.int8, tl.int16, tl.int32, tl.int64) for arg in args] + return tl.tensor(builder.create_print(prefix, hex, new_args, is_signed), tl.void) + + +def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor: + if not builder.options.debug: return - if not hint_val: - hint_val = builder.get_unit_attr() - elif isinstance(hint_val, bool): - hint_val = builder.get_bool_attr(hint_val) - elif isinstance(hint_val, int): - hint_val = builder.get_int32_attr(hint_val) - elif isinstance(hint_val, tl.constexpr): - hint_val = builder.get_str_attr(hint_val.value) - elif isinstance(hint_val, list): - # only support i64 array attr for now - hint_val = builder.get_i64_array_attr(hint_val) - else: - raise ValueError(f"Unsupported hint value type: {type(hint_val)}") - builder.create_annotation(ptr.handle, hint_name, hint_val) - - -def ext_semantic_custom_op(builder: ir.builder, op_name: str, **kwargs): - if op_name == "sync_block_all": - return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["mode"], kwargs["event_id"]) - - elif op_name == "sync_block_set": - return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"]) - - elif op_name == "sync_block_wait": - return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"]) - - raise ValueError(f"Unsupported custom op: {op_name}") - - -def ext_semantic_sort(ptr: tl.tensor, dim: int, descending, builder: ir.builder): - """ - Triton sort 操作 - - 参数: - ptr: tl.tensor,输入张量 - dim: int,排序维度,必须是尾轴(最后一维) - descending: bool 或 constexpr,是否降序 - builder: ir.builder,底层 IR 构建器 - 返回: - values: tl.tensor,排序后的值(类型与输入一致) - """ - - allowed_types = { - tl.int8, tl.int16, tl.bfloat16, tl.float16, tl.float32, tl.int32, tl.int64, tl.float8e4nv, tl.float8e5 - } - base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type - if base_ty not in allowed_types: - raise TypeError( - f"tt.sort only supports int8, int16, bfloat16, float16, float32, int32, int64, float8e4nv, float8e5" - f"but got {ptr.type}") - - shape = getattr(ptr, "shape", None) - if shape is None or shape == (): - shape = getattr(getattr(ptr, "type", None), "shape", None) - - rank = None - if shape is not None: - try: - rank = len(shape) - except Exception: - rank = len(list(shape)) - - if rank is not None: - if rank < 1: - raise ValueError("tt.sort requires tensor rank >= 1") - last_dim = rank - 1 - norm_dim = dim if dim >= 0 else dim + rank - if norm_dim != last_dim: - raise ValueError(f"tt.sort only supports sorting along the last dimension " - f"(dim={last_dim} or -1) for shape {tuple(shape)}, but got dim={dim}") - dim = last_dim - else: - if dim != -1: - raise ValueError("tt.sort only supports the last dimension; when rank is unknown " - "you must pass dim=-1") + return tl.tensor(builder.create_assert(cond.handle, msg), tl.void) - if hasattr(descending, "value"): - descending = bool(descending.value) - else: - descending = bool(descending) - sorted_vals = builder.create_sort(ptr.handle, dim, descending) +def assume(cond, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_assume(cond.handle), tl.void) - values = tl.tensor(sorted_vals, type=ptr.type) - return values +def _convert_elem_to_ir_value(builder, 
elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) + if isinstance(elem, tl.constexpr): + if require_i64: + assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int64(elem.value) + else: + assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int32(elem.value) + elif isinstance(elem, tl.tensor): + assert elem.numel.value == 1, "Expected a scalar in shape/strides/offsets" + assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets" + if elem.dtype != tl.int64 and require_i64: + return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed()) + elif elem.dtype != tl.int32 and not require_i64: + assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \ + "add a `.to(tl.int32)` or use regular indexing for 64 bit support" + return elem.handle + assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" + + +def _convert_to_ir_values(builder, list_like, require_i64=True): + if hasattr(list_like, "__iter__"): + return [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in list_like] + return [_convert_elem_to_ir_value(builder, list_like, require_i64)] + + +def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor: + # Convert dynamic arguments to IR values + # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t` + shape = _convert_to_ir_values(builder, shape) + strides = _convert_to_ir_values(builder, strides) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Check `base` type + if not base.type.is_ptr() or base.type.element_ty.is_block(): + raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)") + + # Treat `pointer_type` as `pointer_type` + if base.type.element_ty == tl.int1: + base = cast(base, tl.pointer_type(tl.int8, base.type.address_space), builder) + + # Check whether `block_shape` is static + if not hasattr(block_shape, "__iter__"): + block_shape = [block_shape] + block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape] + assert all(isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape), \ + "Expected a list of constant integers (`int32_t` range) in `block_shape`" + + # Check `order` + if not hasattr(order, "__iter__"): + order = [order] + order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order] + assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order" + + # Must have same length + assert all(len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]), \ + "Expected shape/strides/offsets/block_shape to have the same length" + # Build value, the type is: + # `pointer_type>` in Python + # `tt.ptr>` in MLIR + handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order) + return tl.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))) -def ext_semantic_scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + +def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + # Convert dynamic 
offsets to IR values + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Advanced block pointer type is the same as before + return tl.tensor(builder.create_advance(base.handle, offsets), base.type) + + +def scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: if dtype is None: raise ValueError("dtype must be specified when value is not a tensor") if value == 0: @@ -428,15 +1967,40 @@ def ext_semantic_scalar_constant(value, dtype: tl.dtype, builder: ir.builder) -> return tl.tensor(value, dtype) -def ext_semantic_make_scalar(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: +def make_scalar(value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: if isinstance(value, tl.tensor): assert value.numel.value == 1, "only accepts size-1 tensor" return cast(value, dtype, builder) - return ext_semantic_scalar_constant(value, dtype, builder) + return scalar_constant(value, dtype, builder) + + +def descriptor_load(desc: tl.tensor_descriptor_base, offsets, cache_modifier: str, eviction_policy: str, + builder: ir.builder) -> tl.tensor: + assert isinstance(desc, tl.tensor_descriptor_base) + ndim = len(desc.block_shape) + assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" + + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier), + _str_to_eviction_policy(eviction_policy)) + return tl.tensor(x, desc.block_type) + + +def validate_store_like(desc: tl.tensor_descriptor_base, value: tl.tensor, offsets) -> None: + assert isinstance(desc, tl.tensor_descriptor_base) + ndim = len(desc.block_shape) + assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" + assert value.shape == desc.block_shape + + +def descriptor_store(desc: tl.tensor_descriptor_base, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + validate_store_like(desc, value, offsets) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void) -def ext_semantic_make_tensor_descriptor(base: tl.tensor, shape: List[tl.tensor], strides: List[tl.tensor], - block_shape: List[tl.constexpr], builder: ir.builder) -> tensor_descriptor: +def make_tensor_descriptor(base: tl.tensor, shape: List[tl.tensor], strides: List[tl.tensor], + block_shape: List[tl.constexpr], builder: ir.builder) -> tl.tensor_descriptor: ndim = len(shape) if not (1 <= ndim <= 5): raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions") @@ -449,485 +2013,26 @@ def ext_semantic_make_tensor_descriptor(base: tl.tensor, shape: List[tl.tensor], if primitive_bitwidth == 1: raise ValueError("int1 type is not supported for make_tensor_descriptor yet") elem_size = primitive_bitwidth // 8 - contig_dim_size = _unwrap_if_constexpr(block_shape[-1]) + contig_dim_size = tl._unwrap_if_constexpr(block_shape[-1]) if contig_dim_size * elem_size < 16: raise ValueError( f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes" ) - strides[-1] = _unwrap_if_constexpr(strides[-1]) + strides[-1] = tl._unwrap_if_constexpr(strides[-1]) if strides[-1] != 1: raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}") - shape = [ext_semantic_make_scalar(x, tl.int32, builder) for x in shape] - strides = [ext_semantic_make_scalar(x, tl.int64, builder) 
for x in strides] + shape = [make_scalar(x, tl.int32, builder) for x in shape] + strides = [make_scalar(x, tl.int64, builder) for x in strides] - block_shape = _unwrap_shape(block_shape) + block_shape = tl._unwrap_shape(block_shape) assert isinstance(base.type, tl.pointer_type) - desc_block_type = block_type(base.type.element_ty, block_shape) + desc_block_type = tl.block_type(base.type.element_ty, block_shape) base_handle = base.handle is_signed_int = base.type.element_ty.is_int_signed() handle = builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape], [s.handle for s in strides], block_shape, is_signed_int) - return tensor_descriptor(handle, shape, strides, desc_block_type) - - -def ext_semantic_index_select_simd(src: tl.tensor, dim: int, index: tl.tensor, src_shape: List[Union[int, tl.tensor]], - src_offset: List[Union[int, tl.tensor]], read_shape: List[Union[int, tl.tensor]], - builder: ir.builder) -> tl.tensor: - """ - Index select operation (SIMD version) that loads data from multiple indices along a dimension. - - Args: - src: Source tensor pointer (in GM) - dim: Dimension along which to select indices - index: 1D tensor of indices to select (in UB) - src_shape: Complete shape of source tensor. Each element can be int or tensor. - src_offset: Starting offset for reading. Each element can be int or tensor. - read_shape: Size to read (tile shape). Each element can be int or tensor. - builder: IR builder - - Returns: - Result tensor in UB - - Constraints: - - read_shape[dim] must be -1 - - src_offset[dim] can be -1 (ignored) - - All list parameters must have the same length (ndim) - """ - # Validate inputs - ndim = len(src_shape) - assert len(src_offset) == ndim, \ - f"src_offset length {len(src_offset)} must match src_shape length {ndim}" - assert len(read_shape) == ndim, \ - f"read_shape length {len(read_shape)} must match src_shape length {ndim}" - assert 0 <= dim < ndim, \ - f"dim={dim} must be in range [0, {ndim})" - assert len(index.shape) == 1, \ - f"index must be 1D tensor, got {len(index.shape)}D" - assert dim < ndim - 1, \ - f"index_select_simd cannot support trailing dimension as dim={dim}, ndim={ndim}" - - newsrc_shape = [o.handle for o in src_shape] - newsrc_offset = [o.handle for o in src_offset] - # Create output type - return_shape = [index.shape[0] if i == dim else read_shape[i] for i in range(ndim)] - element_ty = src.type.element_ty - output_ty = tl.block_type(element_ty, return_shape) - out = builder.create_index_select_simd(src.handle, index.handle, dim, newsrc_shape, newsrc_offset, read_shape, - return_shape) - return tl.tensor(out, output_ty) - - -def ext_semantic__load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): - # Load by a block pointer: pointer_type> - # Block pointer can not have mask and other arguments - if mask is not None or other is not None: - raise ValueError("mask and other arguments cannot be specified for loading block pointers") - - elt_ty = ptr.type.element_ty.element_ty - assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" - if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: - raise ValueError("Padding option `nan` is not supported for integer block pointers") - - # `dst_ty` is de-referenced type of the pointer type - dst_ty = ptr.type.element_ty - - # Check `boundary_check` argument - boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) - - if boundary_check and padding is None: - padding = 
ir.PADDING_OPTION.PAD_ZERO - - # Build IR - return tl.tensor( - builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) - - -def ext_semantic_flip_simd(ptr: tl.tensor, dim: int, builder: ir.builder): - """ - Triton flip operation for simd - - Args: - ptr: tl.tensor, input tensor - dim: int, dimension to flip (can be negative, normalized here) - builder: ir.builder, underlying IR builder - Returns: - flipped: tl.tensor, same type and shape as input - """ - - shape = getattr(ptr, "shape", None) - if shape is None or shape == (): - shape = getattr(getattr(ptr, "type", None), "shape", None) - - rank = None - if shape is not None: - try: - rank = len(shape) - except Exception: - rank = len(list(shape)) - - if rank is not None: - if rank < 1: - raise ValueError("tt.flip requires tensor rank >= 1") - norm_dim = dim if dim >= 0 else dim + rank - if not (0 <= norm_dim < rank): - raise ValueError(f"tt.flip got invalid dim={dim} for shape {tuple(shape)}") - dim = norm_dim - else: - if dim < 0: - raise ValueError("tt.flip with unknown rank requires non-negative dim") - - flipped_vals = builder.create_flip(ptr.handle, dim) - flipped = tl.tensor(flipped_vals, type=ptr.type) - return flipped - - -def _get_flip_dim(dim, shape): - dim = _unwrap_if_constexpr(dim) - shape = _unwrap_if_constexpr(shape) - if dim is None: - dim = len(shape) - 1 - if dim < 0: # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index - dim += len(shape) - return core.constexpr(dim) - - -def _log2(i: core.constexpr): - log2 = 0 - n = core.constexpr(i).value - while n > 1: - n >>= 1 - log2 += 1 - return core.constexpr(log2) - - -def ext_semantic_flip(ptr: tl.tensor, dim: int, builder: ir.builder, generator=None): - """ - Flips a tensor `ptr` along the dimension `dim`. - - :param ptr: the first input tensor - :type ptr: tl.tensor - :param dim: the dimension to flip along - :type dim: int - :param generator: the code generator (required for reduce operations) - :type generator: generator object - """ - - # If compile_mode is not simt, use the simd implementation - if not builder.is_simt_mode(): - return ext_semantic_flip_simd(ptr, dim, builder) - core.static_assert(-len(ptr.shape) <= dim and dim < len(ptr.shape), _builder=builder) - _dim: core.constexpr = _get_flip_dim(dim, ptr.shape) - core.static_assert(standard._is_power_of_two(ptr.shape[_dim]), _builder=builder) - steps: core.constexpr = _log2(ptr.shape[_dim]) - # If steps is 0, return the original tensor - if steps == 0: - return ptr - # reshape the swap dimension to (2, 2, ..., 2) - idtype = core.get_int_dtype(bitwidth=ptr.dtype.primitive_bitwidth, signed=True) - y = core.reshape( - ptr.to(idtype, bitcast=True, _builder=builder), - ptr.shape.__getitem__(slice(None, _dim)) + [2] * steps + ptr.shape.__getitem__(slice(_dim + 1, None)), - _builder=builder) - for i in ext_semantic_static_range(steps): - y = y.__xor__(standard.xor_sum(y, _dim + i, True, _builder=builder, _generator=generator), _builder=builder) - ptr = core.reshape(y, ptr.shape, _builder=builder).to(ptr.dtype, bitcast=True, _builder=builder) - return ptr - - -class ext_semantic_static_range: - """ - Iterator for non-JIT Python functions that need to iterate over constexpr values. - This is used in functions like flip that are called during compilation. 
- """ - - def __init__(self, arg1, arg2=None, step=None): - if step is None: - self.step = core.constexpr(1) - else: - self.step = step - if arg2 is None: - self.start = core.constexpr(0) - self.end = arg1 - else: - self.start = arg1 - self.end = arg2 - - def __iter__(self): - # Extract actual values from constexpr objects for iteration - start_val = core._constexpr_to_value(self.start) - end_val = core._constexpr_to_value(self.end) - step_val = core._constexpr_to_value(self.step) - # Store as regular Python integers for iteration - self._current = start_val - self._end = end_val - self._step = step_val - return self - - def __next__(self): - if self._current >= self._end: - raise StopIteration - value = self._current - self._current += self._step - return value - - -def _convert_elem_to_ir_value(builder, elem, require_i64): - if isinstance(elem, int): - elem = tl.constexpr(elem) - if isinstance(elem, tl.constexpr): - if require_i64: - assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ - f"got a value {elem.value} which is out of the range" - return builder.get_int64(elem.value) - else: - assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ - f"got a value {elem.value} which is out of the range" - return builder.get_int32(elem.value) - elif isinstance(elem, tl.tensor): - if require_i64: - return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed()) - else: - return builder.create_int_cast(elem.handle, builder.get_int32_ty(), elem.dtype.is_int_signed()) - else: - assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" - - -def ext_semantic_embedding_gather(src: tl.tensor, idx: tl.tensor, bound: int, blksiz: int, offsets: Tuple, - numels: Tuple, builder: ir.builder) -> tl.tensor: - """ - Embedding - :src_ptr: - :idx: - """ - assert idx.dtype.is_int(), "index must be an integer tensor" - if not src.dtype.element_ty.is_floating(): - raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {src.dtype.element_ty}") - - require_i64 = idx.dtype.is_int64() - # require_i64 = True - offsets = [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in offsets] - numels = [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in numels] - ret = builder.create_embedding_gather(src.handle, idx.handle, bound, blksiz, offsets, numels) - ret_shape = [_unwrap_if_constexpr(s) for s in idx.shape] - ret_shape.append(blksiz) - return wrap_tensor(ret, src.dtype.element_ty, ret_shape) - - -def ext_semantic_index_put(ptr: tl.tensor, index: tl.tensor, value: tl.tensor, dim: int, index_boundary: int, - end_offset: Tuple, start_offset: Tuple, dst_stride: Tuple, builder: ir.builder): - """ - Index put values from a tensor into a destination tensor. - - Index put operation for different tensor ranks: - 1. 2D index scatter (0 <= dim < 1): - 1.1 dim = 0 - out[index[i]][start_offset[1]:end_offset[1]] = value[i][0:end_offset[1]-start_offset[1]] - 2. 
3D index scatter (0 <= dim < 2): - 2.1 dim = 0 - out[index[i]][start_offset[1]:end_offset[1]][start_offset[2]:end_offset[2]] - = value[i][0:end_offset[1]-start_offset[1]][0:end_offset[2]-start_offset[2]] - 2.2 dim = 1 - out[start_offset[0]:end_offset[0]][index[j]][start_offset[2]:end_offset[2]] - = value[0:end_offset[0]-start_offset[0]][j][0:end_offset[2]-start_offset[2]] - - Args: - - ptr: pointer type, the destination tensor pointer (in GM) - - index: tensor, a index to scatter (in UB) - - value: tensor, a value to store (in UB) - - dim: int32, the dimension to scatter along - - index_boundary: int64, the upper boundary for index values - - end_offset: tuple of int, the offsets of each dimension for the end of the scatter region - - start_offset: tuple of int, the offsets of each dimension for the start of the scatter region - - dst_stride: tuple of int, the stride of each dimension of destination tensor - - Constraints: - - `ptr` and `value` must have the same rank. - - `ptr.dtype` only supports `float16`, `bfloat16`, `float32` currently. - - `index` must be an integer tensor. If `index.rank` != 1, it will be reshaped to 1D. - - `index.numel` must equal `value.shape[dim]`. - - `value` support 2~5D tensors. - - `dim` must be valid (0 <= dim < rank(value) - 1). - """ - assert index.dtype.is_int(), "index must be an integer tensor" - if not ptr.dtype.element_ty.is_floating(): - raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {ptr.dtype.element_ty}") - if not isinstance(dim, int): - raise ValueError("dim must be of type tl.constexpr") - - v_rank = len(value.shape) - idx_rank = len(index.shape) - if v_rank < 2 or v_rank > 5: - raise ValueError(f"value rank must be in [2, 5], got value rank={v_rank}") - if dim < 0 or dim >= v_rank - 1: - raise ValueError(f"dim must satisfy 0<=dim 5: - raise ValueError(f"index rank must be in [1, 5], got rank={idx_rank}") - if dim < 0 or dim >= idx_rank: - raise ValueError(f"dim must satisfy 0<=dim 5: - raise ValueError(f"index rank must be in [1, 5], got rank={idx_rank}") - if dim < 0 or dim >= idx_rank: - raise ValueError(f"dim must satisfy 0<=dim 1: + n >>= 1 + log2 += 1 + return core.constexpr(log2) + + +def _is_power_of_two(i: core.constexpr): + n = i.value + return core.constexpr((n & (n - 1)) == 0 and n != 0) + + +# ----------------------- +# Standard library +# ----------------------- + +# @core._tensor_member_fn +# @jit +# def cdiv(x, div): +# """ +# Computes the ceiling division of :code:`x` by :code:`div` + +# :param x: the input number +# :type x: Block +# :param div: the divisor +# :type div: Block +# """ +# return (x + div - 1) // div @core._tensor_member_fn @jit @math._add_math_1arg_docstr("sigmoid") def sigmoid(x): - _is_int8_type: core.constexpr = x.dtype.is_int8() - core.static_assert(not _is_int8_type, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") - _is_floating_type: core.constexpr = x.dtype.is_floating() - core.static_assert(_is_floating_type == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}") - return (1 / (1 + math.exp(-x.to(core.float32)))).to(x.dtype) + return 1 / (1 + math.exp(-x)) @core._tensor_member_fn @jit @math._add_math_1arg_docstr("softmax") def softmax(x, ieee_rounding=False): - _is_int8_type: core.constexpr = x.dtype.is_int8() - core.static_assert(not _is_int8_type, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") - _is_floating_type: core.constexpr = x.dtype.is_floating() - core.static_assert(_is_floating_type == True, f"Expected dtype fp16/fp32/bf16, but got 
{core.constexpr(x.dtype)}") - z = x.to(core.float32) - max(x, 0) + z = x - max(x, 0) num = math.exp(z) den = sum(num, 0) - return math.fdiv(num, den, ieee_rounding).to(x.dtype) + return math.fdiv(num, den, ieee_rounding) @core._tensor_member_fn @jit -@math._add_math_1arg_docstr("isfinited") -def isfinited(x): - _is_int8_type: core.constexpr = x.dtype.is_int8() - core.static_assert(not _is_int8_type, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") - _is_floating_type: core.constexpr = x.dtype.is_floating() - core.static_assert(_is_floating_type == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}") - nan_mask = math.isnan(x) - inf_mask = math.isinf(x) - return (~nan_mask & ~inf_mask).to(int1) +def ravel(x): + """ + Returns a contiguous flattened view of :code:`x`. + + :param x: the input tensor + :type x: Block + """ + return core.reshape(x, [x.numel], can_reorder=False) -@core._tensor_member_fn @jit -@math._add_math_1arg_docstr("finitef") -def finitef(x): - _is_int8_type: core.constexpr = x.dtype.is_int8() - core.static_assert(not _is_int8_type, f"finitef only supports float32, but got int8 or int1") - core.static_assert(x.dtype == float32, f"finitef only supports float32, but got {core.constexpr(x.dtype)}") - nan_mask = math.isnan(x) - inf_mask = math.isinf(x) - return (~nan_mask & ~inf_mask).to(int1) +def swizzle2d(i, j, size_i, size_j, size_g): + """ + Transforms the indices of a row-major `size_i * size_j` matrix into + the indices of a column-major matrix for each group of `size_g` rows. + + For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will + transform :: + + [[0 , 1 , 2 , 3 ], + [4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11], + [12, 13, 14, 15]] + + into :: + + [[0, 2, 4 , 6 ], + [1, 3, 5 , 7 ], + [8, 10, 12, 14], + [9, 11, 13, 15]] + """ + # "unrolled index in array" + ij = i * size_j + j + # number of elements in `size_g` groups + # of `size_j` columns + size_gj = size_g * size_j + # index of the group in which (i,j) is + group_id = ij // size_gj + # row-index of the first element of this group + off_i = group_id * size_g + # last group may have fewer rows + size_g = core.minimum(size_i - off_i, size_g) + # linear index with respect to the first element in this group + ij = ij % size_gj + # new row and column indices + new_i = off_i + ij % size_g + new_j = ij // size_g + return new_i, new_j -@core._tensor_member_fn @jit -@math._add_math_1arg_docstr("rint") -def rint(x): - _is_int8_type: core.constexpr = x.dtype.is_int8() - core.static_assert(not _is_int8_type, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") - _is_floating_type: core.constexpr = x.dtype.is_floating() - core.static_assert(_is_floating_type == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}") - # Calculate integer part and fractional part - floor_x = math.floor(x) - fractional = x - floor_x - # Check if fractional part is close to 0.5 - is_half = math.abs(fractional - 0.5) < 1e-8 - # Check if integer part is even - floor_int = floor_x.to(int32) - is_even = (floor_int % 2) == 0 - # Apply bankers rounding rules: - # - If fractional part is 0.5: keep integer part if even, add 1 if odd - # - Otherwise: round to the nearest integer directly - return core.where(is_half, core.where(is_even, floor_x, floor_x + 1.0), - core.where(x >= 0, math.floor(x + 0.5), math.ceil(x - 0.5))) - - -pi: core.constexpr = math_pi +def zeros(shape, dtype): + """ + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. 
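    For example (illustrative), :code:`zeros((8, 16), tl.float16)` returns an 8x16 block of fp16 zeros,
    i.e. the same result as :code:`tl.full((8, 16), 0, tl.float16)`.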
+ + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + return core.full(shape, 0, dtype) -@core._tensor_member_fn @jit -@math._add_math_2arg_docstr("atan2") -def atan2(y, x): - _is_int8_type_x: core.constexpr = x.dtype.is_int8() - core.static_assert(not _is_int8_type_x, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") - _is_int8_type_y: core.constexpr = y.dtype.is_int8() - core.static_assert(not _is_int8_type_y, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") - _is_floating_type_x: core.constexpr = x.dtype.is_floating() - core.static_assert(_is_floating_type_x == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}") - _is_floating_type_y: core.constexpr = y.dtype.is_floating() - core.static_assert(_is_floating_type_y == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(y.dtype)}") - half_pi: core.constexpr = 0.5 * pi - base = core.where(x == 0, 0.0, math.atan(y.to(float32) / x.to(float32))) - base = core.where((x == 0) & (y > 0), half_pi, base) - base = core.where((x == 0) & (y < 0), -half_pi, base) +def zeros_like(input): + """ + Returns a tensor of zeros with the same shape and type as a given tensor. + + :param input: input tensor + :type input: Tensor + """ + return zeros(input.shape, input.dtype) + - add_pi = core.where((x < 0) & (y >= 0), pi, 0.0) - sub_pi = core.where((x < 0) & (y < 0), -pi, 0.0) - return (base + add_pi + sub_pi).to(x.dtype) +# max and argmax @jit @@ -109,7 +142,7 @@ def _argmax_combine(value1, index1, value2, index2, tie_break_left): if tie_break_left: tie = value1 == value2 and index1 < index2 else: - tie = value1 == value2 and index1 > index2 + tie = False gt = value1 > value2 or tie v_ret = core.where(gt, value1, value2) i_ret = core.where(gt, index1, index2) @@ -154,6 +187,11 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr input = input.to(core.float32) else: assert input.dtype.is_int(), "Expecting input to be integer type" + # FIXME: Skip int8/int16 -> int32 promotion on Ascend. + # Converting small integer types (e.g., int8) to int32 consumes excessive UB (Unified Buffer) memory, + # which can lead to "UB overflow" errors during kernel execution. + # Therefore, we keep the original narrow integer type and rely on backend support. + pass # Do not promote to int32 if not propagate_nan: return core.reduce(input, axis, _elementwise_max_default, keep_dims=keep_dims) else: @@ -168,12 +206,15 @@ def argmax(input, axis, tie_break_left=True, keep_dims=False): return ret +# min and argmin + + @jit def _argmin_combine(value1, index1, value2, index2, tie_break_left): if tie_break_left: tie = value1 == value2 and index1 < index2 else: - tie = value1 == value2 and index1 > index2 + tie = False lt = value1 < value2 or tie value_ret = core.where(lt, value1, value2) index_ret = core.where(lt, index1, index2) @@ -212,6 +253,11 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr input = input.to(core.float32) else: assert input.dtype.is_int(), "Expecting input to be integer type" + # FIXME: Skip int8/int16 -> int32 promotion on Ascend. + # Converting small integer types (e.g., int8) to int32 consumes excessive UB (Unified Buffer) memory, + # which can lead to "UB overflow" errors during kernel execution. + # Therefore, we keep the original narrow integer type and rely on backend support. 
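# (Editor's note: illustrative arithmetic, not part of the original patch.)
# For scale: a (256, 128) int8 tile occupies 256 * 128 = 32 KiB of UB, while
# the same tile promoted to int32 needs 4 * 32 KiB = 128 KiB, which can easily
# exhaust the limited per-core UB and trigger the "UB overflow" mentioned above.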
+ pass # Do not promote to int32 return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims) @@ -223,6 +269,22 @@ def argmin(input, axis, tie_break_left=True, keep_dims=False): return ret +@jit +def _sum_combine(a, b): + return a + b + + +# sum + + +@core._tensor_member_fn +@jit +@core._add_reduction_docstr("sum") +def sum(input, axis=None, keep_dims=False): + input = core._promote_bfloat16_to_float32(input) + return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims) + + @jit def _xor_combine(a, b): return a ^ b @@ -232,11 +294,44 @@ def _xor_combine(a, b): @core._tensor_member_fn -@jit +@core.builtin @core._add_reduction_docstr("xor sum") -def xor_sum(x, axis=None, keep_dims=False): - core.static_assert(x.type.scalar.is_int(), "xor_sum only supported for integers") - return core.reduce(x, axis, _xor_combine, keep_dims=keep_dims) +def xor_sum(input, axis=None, keep_dims=False, _builder=None, _generator=None): + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + + input = core._promote_bfloat16_to_float32(input, _builder=_builder) + return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims, _builder=_builder, _generator=_generator) + + +# cumsum + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumsum") +def cumsum(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _sum_combine, reverse) + + +# cumprod + + +@jit +def _prod_combine(a, b): + return a * b + + +@core._tensor_member_fn +@jit +@core._add_scan_docstr("cumprod") +def cumprod(input, axis=0, reverse=False): + # todo rename this to a generic function name + input = core._promote_bfloat16_to_float32(input) + return core.associative_scan(input, axis, _prod_combine, reverse) # sort @@ -250,7 +345,27 @@ def _indicator(n_dims: core.constexpr, j: core.constexpr): @jit -def _compare_and_swap(x, flip_dim, i: core.constexpr): +def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr): + n_outer: core.constexpr = x.numel >> n_dims + shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)] + y = core.reshape(x, shape) + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = core.arange(0, 2)[None, :, None] + left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype) + right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape).to(y.dtype) + left = core.reshape(left, x.shape) + right = core.reshape(right, x.shape) + # actual compare-and-swap + idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + ret = ix ^ core.where((left > right) != flip, ileft ^ iright, zeros_like(ix)) + return ret.to(x.dtype, bitcast=True) + + +@jit +def _compare_and_swap_3_4(x, flip, i: core.constexpr): # compare-and-swap on the ith *innermost* dimension n_dims: core.constexpr = _log2(x.numel) @@ -264,7 +379,7 @@ def _compare_and_swap(x, flip_dim, i: core.constexpr): is_right = _indicator(n_dims, i) # conditional swap: - ret = core.where((x > y) != (flip_dim ^ is_right), y, x) + ret = core.where((x > y) != (flip ^ is_right), y, x) return ret @@ -281,12 +396,60 @@ def _bitonic_merge_hypercube(x, stage: core.constexpr, order: core.constexpr): # if flip = 00110011... 
then all the elements will be re-arranged alternatingly (with # a stride of 2) at this stage if order == 2: - flip_dim = _indicator(_log2(x.numel), stage) + flip = _indicator(_log2(x.numel), stage) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in core.static_range(stage): + x = _compare_and_swap_3_4(x, flip, stage - 1 - i) + return x + + +@jit +def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr): + ''' + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + ''' + n_outer: core.constexpr = x.numel >> n_dims + core.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage] + flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape) else: - flip_dim = order + flip = order # perform `stage` rounds of `compare-and-swap` for i in core.static_range(stage): - x = _compare_and_swap(x, flip_dim, stage - 1 - i) + x = _compare_and_swap(x, flip, i + (n_dims - stage), n_dims) + return x + + +@core._tensor_member_fn +@jit +def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0): + """ + Sorts a tensor along a specified dimension. + + :param x: The input tensor to be sorted. + :type x: Tensor + :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported. + :type dim: int, optional + :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order. 
+ :type descending: bool, optional + """ + # handle default dimension or check that it is the most minor dim + _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim + core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + # iteratively run bitonic merge-sort steps + n_dims: core.constexpr = _log2(x.shape[_dim]) + for i in core.static_range(1, n_dims + 1): + x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims) return x @@ -323,7 +486,7 @@ def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descendin # select top k elements using bitonic top-k # https://www.doc.ic.ac.uk/~hlgr/pdfs/MassivelyParallelTopK.pdf for i in core.static_range(log_k + 1, log_n + 1): - h = real_max(h, axis=(_log2(h.numel) - 1 - log_k)) if descending else min(h, axis=(_log2(h.numel) - 1 - log_k)) + h = max(h, axis=(_log2(h.numel) - 1 - log_k)) if descending else min(h, axis=(_log2(h.numel) - 1 - log_k)) h = _bitonic_merge_hypercube(h, log_k, 2 if i < log_n else descending) # reshape back: @@ -336,6 +499,66 @@ def topk(x, k: core.constexpr, dim: core.constexpr = None): return sort_impl(x, k=k, dim=dim, descending=True) -standard_ext_spec_api_list = [ - "sigmoid", "softmax", "isfinited", "finitef", "rint", "atan2", "argmax", "argmin", "topk", "max" -] +# flip + + +def _get_flip_dim(dim, shape): + dim = core._unwrap_if_constexpr(dim) + shape = core._unwrap_if_constexpr(shape) + if dim is None: + dim = len(shape) - 1 + assert dim == len(shape) - 1, "Currently only support flipping the last dimension" + return core.constexpr(dim) + + +@core._tensor_member_fn +@jit +def flip(x, dim=None): + """ + Flips a tensor `x` along the dimension `dim`. + + :param x: the first input tensor + :type x: Block + :param dim: the dimension to flip along (currently only final dimension supported) + :type dim: int + """ + core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)])) + core.static_assert(_is_power_of_two(x.numel)) + # # reshape the tensor to have all dimensions be 2. + # # TODO: We shouldn't have to change the dimensions not sorted. + steps: core.constexpr = _log2(x.numel) + start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)]) + y = core.reshape(x, [2] * steps) + y = core.expand_dims(y, start) + flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2)) + for i in core.static_range(start, steps): + flip2 = flip + for j in core.static_range(0, steps + 1): + if j != i and j != i + 1: + flip2 = core.expand_dims(flip2, j) + y = sum(y * flip2, i + 1, keep_dims=True) + x = core.reshape(y, x.shape) + return x + + +@jit +def interleave(a, b): + """ + Interleaves the values of two tensors along their last dimension. The two tensors must have the same shape. + Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])` + + :param a: The first input tensor. + :type a: Tensor + :param b: The second input tensor. + :type b: Tensor + """ + c = core.join(a, b) + + if len(c.shape) == 1: + # We must have interleaved two scalars. + return c + else: + # This `else` is necessary because Triton's AST parser doesn't + # understand that if we take the `if` above we definitely don't run this + # `else`. 
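+        # Collapse the last two axes of `join`'s output into a single
+        # interleaved axis of length 2 * a.shape[-1].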
+ return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]]) diff --git a/third_party/ascend/backend/spec/triton/language/tensor_descriptor.py b/third_party/ascend/backend/spec/triton/language/tensor_descriptor.py deleted file mode 100644 index e125fcc5e..000000000 --- a/third_party/ascend/backend/spec/triton/language/tensor_descriptor.py +++ /dev/null @@ -1,694 +0,0 @@ -# TODO: When upgrading to Triton 3.4.0, remove this file, -# use the upstream Triton functions, and update core.py and semantic.py accordingly. -from __future__ import annotations - -import builtins -from typing import List, Tuple, Sequence, TypeVar -from enum import Enum - -from triton._C.libtriton import ir -from triton.language.core import ( - builtin, - constexpr, - tensor, - _value, - void as real_void, -) - -from triton.language.semantic import ( - _convert_to_ir_values, - _str_to_load_cache_modifier, - _str_to_eviction_policy, -) - -from triton.language._utils import validate_block_shape - - -def _unwrap_if_constexpr(o): - if isinstance(o, list): - return [_unwrap_if_constexpr(x) for x in o] - if isinstance(o, builtins.tuple): - return builtins.tuple(_unwrap_if_constexpr(x) for x in o) - if isinstance(o, tuple): - return tuple(_unwrap_if_constexpr(x) for x in o) - return o.value if isinstance(o, constexpr) else o - - -def _unwrap_shape(shape): - shape = _unwrap_if_constexpr(shape) - return [_unwrap_if_constexpr(s) for s in shape] - - -def _normalize_tuple(t): - normalized_tuple = _unwrap_if_constexpr(t) - if isinstance(normalized_tuple, (list, builtins.tuple)): - normalized_tuple = tuple(normalized_tuple) - return normalized_tuple - - -def descriptor_load(desc: tensor_descriptor_base, offsets, cache_modifier: str, eviction_policy: str, - builder: ir.builder) -> tensor: - assert isinstance(desc, tensor_descriptor_base) - ndim = len(desc.block_shape) - assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" - - offsets = _convert_to_ir_values(builder, offsets, require_i64=False) - x = builder.create_descriptor_load(desc.handle, offsets, _str_to_load_cache_modifier(cache_modifier), - _str_to_eviction_policy(eviction_policy)) - return tensor(x, desc.block_type) - - -def validate_store_like(desc: tensor_descriptor_base, value: tensor, offsets) -> None: - assert isinstance(desc, tensor_descriptor_base) - ndim = len(desc.block_shape) - assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" - assert value.shape == desc.block_shape - - -def descriptor_store(desc: tensor_descriptor_base, value: tensor, offsets, builder: ir.builder) -> tensor: - validate_store_like(desc, value, offsets) - offsets = _convert_to_ir_values(builder, offsets, require_i64=False) - return tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), real_void) - - -class base_value(_value): - """Base class of values that exist in the triton IR (i.e. not constexprs). - """ - type: base_type - - def _flatten_ir(self, handles: List[ir.value]) -> None: - """Flatten frontend value into a sequence of mlir handles, which are appended - to the output list - """ - raise NotImplementedError - - -class base_type: - - def __eq__(self, other): - raise NotImplementedError("Types must implement __eq__") - - def __ne__(self, other): - return not (self == other) - - def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: - """Build a frontend value with the current dtype, wrapping a list of existing handles. 
- cursor is the index of the first handle relevant to this value, and the function - should return the updated cursor position after any handles consumed by the created value. - """ - raise NotImplementedError - - def mangle(self) -> str: - raise NotImplementedError(f"NYI: Type mangling for type {self.__class__}") - - def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: - raise NotImplementedError - - -class tuple(base_value): - - def __init__(self, args: Sequence, type: tuple_type = None): - self.values = [i for i in args] - - def get_type(x): - if isinstance(x, dtype): - return dtype - if isinstance(x, (int, float)): - return constexpr - return x.type - - self.type = type or tuple_type([get_type(x) for x in self.values]) - - def __getitem__(self, idx: constexpr): - if isinstance(idx, int): - idx = constexpr(idx) - if isinstance(idx, constexpr): - return self.values[idx] - else: - assert isinstance(idx, (slice, builtins.slice)) - return tuple(self.values[idx.start:idx.stop:idx.step]) - - def __getattr__(self, name): - return self.values[self.type.fields.index(name)] - - def __setitem__(self, idx: constexpr, value): - if isinstance(idx, int): - idx = constexpr(idx) - assert isinstance(idx, constexpr) - self.values[idx] = value - - def __add__(self, other): - other = _normalize_tuple(other) - return tuple(self.values + other.values) - - def __mul__(self, other): - assert isinstance(other, constexpr) - return tuple(self.values * other.value) - - def __eq__(self, other): - other = _normalize_tuple(other) - return constexpr(self.values == other.values) - - def __hash__(self): - return hash(builtins.tuple(self.values)) - - def __str__(self): - return str([str(x) for x in self.values]) - - def __iter__(self): - return iter(self.values) - - def __len__(self): - return len(self.values) - - def _flatten_ir(self, handles: List[ir.value]): - for v in self.values: - print("[debug]tuple _flatten_ir: value:", v) - v._flatten_ir(handles) - print("[debug]tuple _flatten_ir: handles:", handles) - - def __repr__(self): - return f"({' ,'.join(repr(x) for x in self.values)})" - - -class tuple_type(base_type): - - def __init__(self, types, fields=None): - self.types = types - self.fields = fields or [''] * len(types) - self.name = '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']' - - def __str__(self): - return self.name - - def __iter__(self): - return iter(self.types) - - def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]): - for ty in self.types: - if not isinstance(ty, constexpr): - ty._flatten_ir_types(builder, out) - - def __getitem__(self, index: int) -> dtype: - return self.types[index] - - def __eq__(self, other): - return type(self) is type(other) and self.types == other.types and self.fields == other.fields - - def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]: - values = [] - for ty in self.types: - value, cursor = ty._unflatten_ir(handles, cursor) - values.append(value) - return tuple(values, self), cursor - - def mangle(self): - return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T' - - -class dtype(base_type): - SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] - UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] - FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64'] - STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] - OTHER_TYPES = ['void'] - - class SIGNEDNESS(Enum): - SIGNED = 0 - UNSIGNED = 1 - - class KIND(Enum): - BOOLEAN = 0 - 
INTEGRAL = 1 - FLOATING = 2 - - def __init__(self, name): - name = _unwrap_if_constexpr(name) - self.name = name - assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name - # flagtree backend specialization - from ._utils import spec_get_primitive_bitwidth - get_primitive_bitwidth = spec_get_primitive_bitwidth - self.primitive_bitwidth = get_primitive_bitwidth(name) - self.itemsize = self.primitive_bitwidth // 8 - if name in dtype.SINT_TYPES: - self.int_signedness = dtype.SIGNEDNESS.SIGNED - self.int_bitwidth = self.primitive_bitwidth - elif name in dtype.UINT_TYPES: - self.int_signedness = dtype.SIGNEDNESS.UNSIGNED - self.int_bitwidth = self.primitive_bitwidth - elif name in dtype.FP_TYPES: - if name == 'fp8e4b15': - self.fp_mantissa_width = 3 - self.exponent_bias = 15 - elif name == 'fp8e4nv': - self.fp_mantissa_width = 3 - self.exponent_bias = 7 - elif name == 'fp8e4b8': - self.fp_mantissa_width = 3 - self.exponent_bias = 8 - elif name == 'fp8e5': - self.fp_mantissa_width = 2 - self.exponent_bias = 15 - elif name == 'fp8e5b16': - self.fp_mantissa_width = 2 - self.exponent_bias = 16 - elif name == 'fp16': - self.fp_mantissa_width = 10 - self.exponent_bias = 15 - elif name == 'bf16': - self.fp_mantissa_width = 7 - self.exponent_bias = 127 - elif name == 'fp32': - self.fp_mantissa_width = 23 - self.exponent_bias = 127 - elif name == 'fp64': - self.fp_mantissa_width = 52 - self.exponent_bias = 1023 - else: - raise RuntimeError(f'Unsupported floating-point type {name}') - - def is_fp8(self): - return 'fp8' in self.name - - def is_fp8e4nv(self): - return self.name == 'fp8e4nv' - - def is_fp8e4b8(self): - return self.name == 'fp8e4b8' - - def is_fp8e4b15(self): - return self.name == 'fp8e4b15' - - def is_fp8e5(self): - return self.name == 'fp8e5' - - def is_fp8e5b16(self): - return self.name == 'fp8e5b16' - - def is_fp16(self): - return self.name == 'fp16' - - def is_bf16(self): - return self.name == 'bf16' - - def is_fp32(self): - return self.name == 'fp32' - - def is_fp64(self): - return self.name == 'fp64' - - def is_int1(self): - return self.name == 'int1' - - def is_int8(self): - return self.name == 'int8' - - def is_int16(self): - return self.name == 'int16' - - def is_int32(self): - return self.name == 'int32' - - def is_int64(self): - return self.name == 'int64' - - def is_uint8(self): - return self.name == 'uint8' - - def is_uint16(self): - return self.name == 'uint16' - - def is_uint32(self): - return self.name == 'uint32' - - def is_uint64(self): - return self.name == 'uint64' - - def is_floating(self): - return self.name in dtype.FP_TYPES - - def is_standard_floating(self): - return self.name in dtype.STANDARD_FP_TYPES - - def is_int_signed(self): - return self.name in dtype.SINT_TYPES - - def is_int_unsigned(self): - return self.name in dtype.UINT_TYPES - - def is_int(self): - return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES - - def is_bool(self): - return self.is_int1() - - def kind(self): - # Return int value following the type ordering bool < integer < fp - if self.is_bool(): - return dtype.KIND.BOOLEAN - elif self.is_int(): - return dtype.KIND.INTEGRAL - else: - assert self.is_floating() - return dtype.KIND.FLOATING - - def get_int_max_value(self): - if self.is_int_signed(): - return 2**(self.int_bitwidth - 1) - 1 - if self.is_int_unsigned(): - return 2**self.int_bitwidth - 1 - assert False - - def get_int_min_value(self): - if self.is_int_signed(): - return -2**(self.int_bitwidth - 1) - if self.is_int_unsigned(): - return 0 - 
assert False - - @staticmethod - def is_dtype(type_str): - return type_str in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES - - @staticmethod - def is_void(): - raise RuntimeError("Not implemented") - - @staticmethod - def is_block(): - return False - - @staticmethod - def is_ptr(): - return False - - @staticmethod - def is_const(): - return False - - def __eq__(self, other: dtype): - if not isinstance(other, dtype): - return False - return self.name == other.name - - def __hash__(self): - return hash((self.name, )) - - @property - def scalar(self): - return self - - def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: - out.append(self.to_ir(builder)) - - def to_ir(self, builder: ir.builder) -> ir.type: - if self.name.startswith("fp8"): - if self.name not in builder.options.supported_fp8_dtypes: - raise ValueError(f'type {self} not supported in this architecture. ' - f'The supported fp8 dtypes are {builder.options.supported_fp8_dtypes}') - - if self.name == 'void': - return builder.get_void_ty() - elif self.name == 'int1': - return builder.get_int1_ty() - elif self.name in ('int8', 'uint8'): - return builder.get_int8_ty() - elif self.name in ('int16', 'uint16'): - return builder.get_int16_ty() - elif self.name in ('int32', 'uint32'): - return builder.get_int32_ty() - elif self.name in ('int64', 'uint64'): - return builder.get_int64_ty() - elif self.name == 'fp8e5': - return builder.get_fp8e5_ty() - elif self.name == 'fp8e5b16': - return builder.get_fp8e5b16_ty() - elif self.name == 'fp8e4nv': - return builder.get_fp8e4nv_ty() - elif self.name == 'fp8e4b8': - return builder.get_fp8e4b8_ty() - elif self.name == 'fp8e4b15': - return builder.get_fp8e4b15_ty() - elif self.name == 'fp16': - return builder.get_half_ty() - elif self.name == 'bf16': - return builder.get_bf16_ty() - elif self.name == 'fp32': - return builder.get_float_ty() - elif self.name == 'fp64': - return builder.get_double_ty() - raise ValueError(f'fail to convert {self} to ir type') - - def __str__(self): - return self.name - - def codegen_name(self): - if self.name.startswith("fp"): - return "float" + self.name[2:] - elif self.name.startswith("bf"): - return "bfloat" + self.name[2:] - else: - return self.name - - @property - def cache_key_part(self) -> str: - """See cache_key_part() in triton.cc.""" - return self.name - - def __repr__(self): - """Output of repr needs to be an evaluatable expression""" - return f'triton.language.{self.codegen_name()}' - - def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]: - return tensor(handles[cursor], self), cursor + 1 - - def mangle(self) -> str: - if self.is_int(): - SIGNED = dtype.SIGNEDNESS.SIGNED - prefix = 'i' if self.int_signedness == SIGNED else 'u' - return prefix + str(self.int_bitwidth) - if self.is_floating(): - return str(self) - if self.is_void(): - return 'V' - return super().mangle() - - def with_element_ty(self, element_ty: dtype): - assert not self.is_block() - return element_ty - - -class block_type(dtype): - - def __init__(self, element_ty: dtype, shape: List): - self.element_ty = element_ty - - # Note that block_type's shape is a list of int - # while tensor's shape is a list of constexpr. - assert (isinstance(shape, (list, tuple))) - - # shape can be empty ([]) when an input is a 0D tensor. 
- self.shape = tuple(_unwrap_shape(shape)) - if not self.shape: - raise TypeError('0d block_type is forbidden') - - self.numel = validate_block_shape(self.shape) - self.name = f'<{self.shape}, {self.element_ty}>' - - def to_ir(self, builder: ir.builder) -> ir.block_type: - return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) - - def __str__(self): - return self.name - - def __repr__(self): - return self.__str__() - - def is_block(self): - return True - - def get_block_shapes(self) -> Tuple[int]: - return self.shape - - def with_element_ty(self, scalar_ty: dtype) -> block_type: - return block_type(scalar_ty, self.shape) - - def __eq__(self, other) -> bool: - if not isinstance(other, block_type): - return False - return self.element_ty == other.element_ty and self.shape == other.shape - - @property - def scalar(self): - return self.element_ty - - def mangle(self) -> str: - elt = self.scalar.mangle() - shape = '_'.join(map(str, self.shape)) - return f'{elt}S{shape}S' - - -class tuple_type(base_type): - - def __init__(self, types, fields=None): - self.types = types - self.fields = fields or [''] * len(types) - self.name = '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']' - - def __str__(self): - return self.name - - def __iter__(self): - return iter(self.types) - - def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]): - for ty in self.types: - if not isinstance(ty, constexpr): - ty._flatten_ir_types(builder, out) - - def __getitem__(self, index: int) -> dtype: - return self.types[index] - - def __eq__(self, other): - return type(self) is type(other) and self.types == other.types and self.fields == other.fields - - def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]: - values = [] - for ty in self.types: - value, cursor = ty._unflatten_ir(handles, cursor) - values.append(value) - return tuple(values, self), cursor - - def mangle(self): - return 'T' + '_'.join(ty.mangle for ty in self.types) + 'T' - - -class tensor_descriptor_base_type(base_type): - - def __init__(self, block_type: block_type): - self.block_type = block_type - - def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: - value = tensor_descriptor_base(handles[cursor], self.block_type) - return value, cursor + 1 - - def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: - is_signed = self.block_type.element_ty.is_int_signed() - out.append(builder.create_tensor_descriptor_type(self.block_type.to_ir(builder), is_signed)) - - def __str__(self) -> str: - # ex. 
"tensor_descriptor" - return f"tensor_descriptor<{self.block_type}>" - - def __eq__(self, other) -> bool: - if type(other) is not type(self): - return False - return self.block_type == other.block_type - - def __neq__(self, other) -> bool: - return not (self == other) - - def mangle(self) -> str: - return f"TD{self.block_type.mangle()}" - - -class tensor_descriptor_base(base_value): - """" - A tensor descriptor with unknown shape and strides - """ - - def __init__(self, handle, block_type: block_type): - """Not called by user code.""" - super().__init__(handle) - - self.handle = handle # IR handle - self.type = tensor_descriptor_base_type(block_type) # Tensor type (block_type) - - def _flatten_ir(self, handles: List[ir.value]) -> None: - handles.append(self.handle) - - @property - def block_type(self): - return self.type.block_type - - @property - def block_shape(self): - return self.type.block_type.shape - - @property - def dtype(self): - return self.type.block_type.element_ty - - def __str__(self) -> str: - return str(self.type) - - @builtin - def load(self, offsets: Sequence[constexpr | tensor], _builder=None) -> tensor: - """Load a block from the descriptor starting at the given element offsets. - - Values outside of the tensor bounds will be filled with zeros. - - :note: Offset must be a multiple of 16-bytes - """ - return descriptor_load(self, offsets, "", "", _builder) - - @builtin - def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _builder=None) -> tensor: - """Store a block from the descriptor starting at the given element offsets. - - Values outside of the tensor bounds will be ignored. - - :note: Offset must be a multiple of 16-bytes - """ - return descriptor_store(self, value, offsets, _builder) - - -class tensor_descriptor_type(tensor_descriptor_base_type): - - def __init__(self, block_type: block_type, shape_type: tuple_type, strides_type: tuple_type): - self.block_type = block_type - self.shape_type = shape_type - self.strides_type = strides_type - - def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor_base, int]: - handle = handles[cursor] - cursor += 1 - shape, cursor = self.shape_type._unflatten_ir(handles, cursor) - strides, cursor = self.strides_type._unflatten_ir(handles, cursor) - shape = shape.values - strides = strides.values - value = tensor_descriptor(handle, shape, strides, self.block_type) - return value, cursor - - def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: - super()._flatten_ir_types(builder, out) - self.shape_type._flatten_ir_types(builder, out) - self.strides_type._flatten_ir_types(builder, out) - - def __eq__(self, other): - return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type - == other.strides_type) - - -class tensor_descriptor(tensor_descriptor_base): - """A descriptor representing a tensor in global memory. 
- """ - - def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type): - """Not called by user code.""" - # IR handle - super().__init__(handle, block_type) - # Global shape - self.shape = tuple(shape) - self.strides = tuple(strides) - self.type = tensor_descriptor_type( - block_type, - shape_type=self.shape.type, - strides_type=self.strides.type, - ) - - def _flatten_ir(self, handles: List[ir.value]) -> None: - handles.append(self.handle) - self.shape._flatten_ir(handles) - self.strides._flatten_ir(handles) diff --git a/third_party/ascend/backend/spec/triton/runtime/autotiling_tuner.py b/third_party/ascend/backend/spec/triton/runtime/autotiling_tuner.py deleted file mode 100644 index 9ecebbe92..000000000 --- a/third_party/ascend/backend/spec/triton/runtime/autotiling_tuner.py +++ /dev/null @@ -1,237 +0,0 @@ -from __future__ import annotations - -import builtins -import os -import time -from typing import Dict, List - -from .autotuner import Autotuner, Config -from .utils import get_byte_per_numel, is_valid_axis_name, valid_axis_names -from .autoparser import SplitAxesParser, TilingAxesParser, LowDimsAxesParser, PtrNumsParser - - -class AutoTilingTuner(Autotuner): - """ - Automatic generateing candidate tiling configs and evaluating their performance to get the best config. - """ - - def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None, - prune_configs_by: Dict = None, warmup=None, rep=None, use_cuda_graph=False, do_bench=None, - auto_profile_dir=None): - """ - :param key: a list of argument name, where the change of arguments in value will triger re-generating candidates configs and evaluating. - The parameters in the list will be assigned axis names in sequence, with the axis name being in - {'x','y','z','w','v','t','rx','ry','rz','rw','rv','rt}, where the prefix 'r' means a reduction axis. - Only the axis name in this param should add perfix 'r' if it's a reduction axis. 
- :type key: List[str] - """ - super().__init__( - fn, - arg_names, - configs, - key, - reset_to_zero, - restore_value, - pre_hook, - post_hook, - prune_configs_by, - warmup, - rep, - use_cuda_graph, - do_bench, - auto_profile_dir, - ) - - if not configs: - self.user_configs = [] - else: - self.user_configs = configs - self.gen_configs = [] # generated configs from TileGenerator - - self.split_params = None - self.tiling_params = None - self.low_dims = None - self.dual_reduction = False - self.persistent_reduction = False - self.input_ptr_num = 0 - if len(key) > len(valid_axis_names): - raise ValueError("Number of parameters exceeds the number of available axes.") - self.keys = {axis: param for axis, param in zip(valid_axis_names, key)} - self.print_autotuning = os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" - - def _gen_tile_configs(self, kv_dict: Dict[str, int], dtype: torch.dtype) -> List[Config]: - from .tile_generator import KernelMeta, TileGenerator - - axis_sizes = {} - for k, v in kv_dict.items(): - if not is_valid_axis_name(k): - continue - if not isinstance(v, int): - raise ValueError(f"Not supported dim type: {type(v)}, `int` is the only supported type") - axis_sizes[k] = v - - kernel_meta = KernelMeta( - axis_sizes, - self.split_params, - self.tiling_params, - self.low_dims, - dtype, - self.persistent_reduction, - self.dual_reduction, - self.input_ptr_num, - ) - tile_gen = TileGenerator(kernel_meta=kernel_meta) - tile_gen.descend_split_tiling() - - self.gen_configs.clear() - self.gen_configs = tile_gen.configs - if len(self.gen_configs) == 0: - print("[WARNING] The generated candidate tiling configs are empty based on provided parameters!") - - if len(self.gen_configs) == 0 and len(self.user_configs) == 0: - return [ - Config( - {}, - num_warps=4, - num_stages=2, - num_ctas=1, - num_buffers_warp_spec=0, - num_consumer_groups=0, - reg_dec_producer=0, - reg_inc_consumer=0, - ) - ] - else: - return self.gen_configs + self.user_configs - - def run(self, *args, **kwargs): - self.nargs = dict(zip(self.arg_names, args)) - used_cached_result = True - - # generate key - all_args = {**self.nargs, **kwargs} - _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} - _kv_dict = {k: _args[v] for k, v in self.keys.items() if v in _args} - key = list(_kv_dict.values()) - - # Currently, we use the dtype with maximum byte length - dtype = None - for _, arg in _args.items(): - if hasattr(arg, "dtype"): - key.append(str(arg.dtype)) - dtype = (arg.dtype if get_byte_per_numel(arg.dtype) >= get_byte_per_numel(dtype) else dtype) - if dtype is None: - raise NotImplementedError("Not support for non-Tensor inputs") - - key = tuple(key) - if key not in self.cache: - miss_params = [arg for arg in self.arg_names if arg not in all_args.keys()] - # parse pointer params nums - self.input_ptr_nums = self.autoparse_ptr_nums(miss_params) - - # parse autotiling axes - if not self.low_dims: - self.low_dims = self.autoparse_low_dims() - if not self.split_params: - self.split_params = self.autoparse_split_params(miss_params) - miss_params = [arg for arg in miss_params if arg not in self.split_params.values()] - if not self.tiling_params: - self.tiling_params = self.autoparse_tiling_params(miss_params) - miss_params = [arg for arg in miss_params if arg not in self.tiling_params.values()] - if miss_params: - raise ValueError(f"Missing required arguments: {miss_params}. " - f"These arguments must be explicitly provided and cannot be automatically tuned. 
" - f"Please ensure that these arguments are passed when calling the function.") - - # prune configs - self.configs = self._gen_tile_configs(_kv_dict, dtype) - pruned_configs = self.prune_configs(kwargs) - if len(pruned_configs) > 1: - used_cached_result = False - bench_start = time.time() - timings = self._batch_bench(*args, configs=pruned_configs, **kwargs) - bench_end = time.time() - self.bench_time = bench_end - bench_start - self.cache[key] = builtins.min(timings, key=timings.get) - full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()} - self.pre_hook(full_nargs, reset_only=True) - self.configs_timings = timings - config = self.cache[key] - else: - config = pruned_configs[0] - else: - config = self.cache[key] - - self.best_config = config - if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: - print(f"Triton autotuning for function {self.base_fn.__name__} finished after " - f"{self.bench_time:.2f}s; best config selected: {self.best_config};") - - if not used_cached_result and self.auto_profile_dir is not None: - self._profile(*args, config=self.best_config, **kwargs) - if config.pre_hook is not None: - full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} - config.pre_hook(full_nargs) - ret = self.fn.run( - *args, - **kwargs, - **config.all_kwargs(), - ) - self.nargs = None - return ret - - def autoparse_split_params(self, candidates_params: List[str]) -> Dict[str, str]: - """ - Extracts the split axis parameters from triton kernel code. - """ - if self.print_autotuning: - print(f"Triton autotuning: Starting split params parsing...") - func_ast = self.fn.parse() - parser = SplitAxesParser(func_ast, self.keys, candidates_params) - split_axes = parser.parse() - if self.print_autotuning: - print(f"Triton autotuning: Split params parsing complete. " - f"Split params: {split_axes}") - return split_axes - - def autoparse_tiling_params(self, candidates_params: List[str]) -> Dict[str, str]: - """ - Extracts the tiling axis parameters from triton kernel code. - """ - if self.print_autotuning: - print(f"Triton autotuning: Starting tiling params parsing...") - func_ast = self.fn.parse() - parser = TilingAxesParser(func_ast, self.keys, candidates_params) - tiling_axes = parser.parse() - if self.print_autotuning: - print(f"Triton autotuning: Tiling params parsing complete. " - f"Tiling params: {tiling_axes}") - return tiling_axes - - def autoparse_low_dims(self) -> List[str]: - """ - Extracts the low dimension axis from triton kernel code. - """ - if self.print_autotuning: - print(f"Triton autotuning: Starting Low dims axes parsing...") - func_ast = self.fn.parse() - parser = LowDimsAxesParser(func_ast, self.keys) - low_dims = parser.parse() - if self.print_autotuning: - print(f"Triton autotuning: Low dims axes parsing complete. " - f"Keys: {self.keys}, Low dims: {low_dims}") - return low_dims - - def autoparse_ptr_nums(self, miss_params: List[str]) -> int: - """ - Counts the number of pointer parameters from triton kernel code. - """ - if self.print_autotuning: - print(f"Triton autotuning: Starting ptr nums parsing...") - func_ast = self.fn.parse() - parser = PtrNumsParser(func_ast, miss_params) - ptr_nums, ptr_params = parser.parse() - if self.print_autotuning: - print(f"Triton autotuning: Pointer nums parsing complete. 
" - f"Pointer params: {ptr_params}, pointer nums: {ptr_nums}") - return ptr_nums diff --git a/third_party/ascend/backend/spec/triton/runtime/autotuner.py b/third_party/ascend/backend/spec/triton/runtime/autotuner.py index 813356d87..5afa5228d 100644 --- a/third_party/ascend/backend/spec/triton/runtime/autotuner.py +++ b/third_party/ascend/backend/spec/triton/runtime/autotuner.py @@ -1,191 +1,408 @@ +from __future__ import annotations + +import builtins import os -import threading -from concurrent.futures import ThreadPoolExecutor -import logging +import time +import inspect +from typing import Dict + +from .jit import KernelInterface +from .errors import OutOfResources +from .driver import driver + + +class Autotuner(KernelInterface): + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by: Dict = None, + warmup=None, + rep=None, + use_cuda_graph=False, + do_bench=None, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. + """ + if not configs: + self.configs = [ + Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0) + ] + else: + self.configs = configs + self.keys = key + self.cache = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_to_zero = [] + if reset_to_zero is not None: + self.reset_to_zero = list(reset_to_zero) + self.restore_value = [] + if restore_value is not None: + self.restore_value = list(restore_value) + + # Hook to reset or restore for required tensors + self.pre_hook = lambda kwargs, reset_only=False: 0 + self.post_hook = lambda kwargs, exception: 0 + self.user_defined_pre_hook = False + self.user_defined_post_hook = False + if pre_hook: + self.pre_hook = pre_hook + self.user_defined_pre_hook = True + elif (len(self.reset_to_zero) > 0 or len(self.restore_value) > 0): + + def _pre_hook(kwargs, reset_only=False): + for name in self.reset_to_zero: + kwargs[name].zero_() + if not reset_only: + self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value} + + self.pre_hook = _pre_hook + + if post_hook: + self.post_hook = post_hook + self.user_defined_post_hook = True + elif len(self.restore_value) > 0: + + def _post_hook(kwargs, exception): + for name in self.restore_value: + kwargs[name].copy_(self.restore_copies[name]) + self.restore_copies = {} + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get("early_config_prune", self.early_config_prune) + + self.fn = fn + self.base_fn = fn + while not inspect.isfunction(self.base_fn): + self.base_fn = self.base_fn.fn + + self.num_warmups = warmup + self.num_reps = rep + self.use_cuda_graph = use_cuda_graph + + # If we got explicitly called via the old interface, raise a warning + # and proceed with the old behavior. 
+ if warmup is not None or rep is not None or use_cuda_graph: + import warnings + warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See " + "https://github.com/triton-lang/triton/pull/4496 for details."), DeprecationWarning, + stacklevel=1) + if use_cuda_graph: + from ..testing import do_bench_cudagraph + self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph( + kernel_call, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + import triton.testing + self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( + kernel_call, + warmup=warmup if warmup is not None else 25, + rep=rep if rep is not None else 100, + quantiles=quantiles, + ) + return + + if do_bench is None: + self.do_bench = driver.active.get_benchmarker() + else: + self.do_bench = do_bench + + def _bench(self, *args, config, **meta): + from ..compiler.errors import CompileTimeAssertionFailure, MLIRCompilationError + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(full_nargs) + try: + self.fn.run( + *args, + **current, + ) + except Exception as e: + try: + self.post_hook(full_nargs, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + self.post_hook(full_nargs, exception=None) -def set_Autotuner_auto_profile_dir(autotuner, auto_profile_dir): - autotuner.auto_profile_dir = auto_profile_dir + try: + return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) + except (OutOfResources, CompileTimeAssertionFailure, MLIRCompilationError): + return [float("inf"), float("inf"), float("inf")] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} + key = [_args[key] for key in self.keys if key in _args] + for _, arg in _args.items(): + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + # prune configs + used_cached_result = False + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()} + self.pre_hook(full_nargs, reset_only=True) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result: + print(f"Triton autotuning for function {self.base_fn.__name__} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};") + if config.pre_hook is not None: + full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()} + config.pre_hook(full_nargs) + ret = self.fn.run( + *args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret + + def 
prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.all_kwargs(), + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append(self.fn.warmup( + *args, + **kwargs, + **config.all_kwargs(), + )) + self.nargs = None + return ret -def ext_Autotuner_do_bench_MLIRCompilationError(): - from ..compiler.errors import MLIRCompilationError - return (MLIRCompilationError, ) +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type kwargs: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :type maxnreg: Optional[int] + :ivar maxnreg: maximum number of registers one thread can use. Corresponds + to ptx .maxnreg directive. Not supported on all platforms. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. 
+ """ + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, + reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.num_buffers_warp_spec = num_buffers_warp_spec + self.num_consumer_groups = num_consumer_groups + self.reg_dec_producer = reg_dec_producer + self.reg_inc_consumer = reg_inc_consumer + self.maxnreg = maxnreg + self.pre_hook = pre_hook + + def all_kwargs(self): + return { + **self.kwargs, **{ + k: v + for (k, v) in ( + ("num_warps", self.num_warps), + ("num_ctas", self.num_ctas), + ("num_stages", self.num_stages), + ("num_buffers_warp_spec", self.num_buffers_warp_spec), + ("num_consumer_groups", self.num_consumer_groups), + ("reg_dec_producer", self.reg_dec_producer), + ("reg_inc_consumer", self.reg_inc_consumer), + ("maxnreg", self.maxnreg), + ) if v is not None + } + } + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}") + res.append(f"num_consumer_groups: {self.num_consumer_groups}") + res.append(f"reg_dec_producer: {self.reg_dec_producer}") + res.append(f"reg_inc_consumer: {self.reg_inc_consumer}") + res.append(f"maxnreg: {self.maxnreg}") + return ", ".join(res) + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None, + warmup=None, rep=None, use_cuda_graph=False, do_bench=None): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(kwargs={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(kwargs={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + + If the environment variable :code:`TRITON_PRINT_AUTOTUNING` is set to + :code:`"1"`, Triton will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. 
+ :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param pre_hook: a function that will be called before the kernel is called. + This overrides the default pre_hook used for 'reset_to_zero' and 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'reset_only': a boolean indicating whether the pre_hook is called to reset the values only, without a corresponding post_hook. + :type pre_hook: lambda args, reset_only + :param post_hook: a function that will be called after the kernel is called. + This overrides the default post_hook used for 'restore_value'. + 'kwargs': a dict of all arguments passed to the kernel. + 'exception': the exception raised by the kernel in case of a compilation or runtime error. + :type post_hook: lambda args, exception + :param warmup: warmup time (in ms) to pass to benchmarking (deprecated). + :type warmup: int + :param rep: repetition time (in ms) to pass to benchmarking (deprecated). + :type rep: int + :param do_bench: a benchmark function to measure the time of each run. + :type do_bench: lambda fn, quantiles + """ -def _tiling_kernel(self, *args, config, **meta): - # check for conflicts, i.e. meta-parameters both provided - # as kwargs and by the autotuner - conflicts = meta.keys() & config.kwargs.keys() - if conflicts: - raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols.") - # augment meta-parameters with tunable ones - current = dict(meta, **config.all_kwargs()) - full_nargs = {**self.nargs, **current} + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, + post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, + use_cuda_graph=use_cuda_graph) - def kernel_call(): - if config.pre_hook: - config.pre_hook(full_nargs) - self.pre_hook(full_nargs) - try: - self.fn.run( - *args, - **current, - ) - except Exception as e: - try: - self.post_hook(full_nargs, exception=e) - finally: - # Throw exception raised by `self.fn.run` - raise + return decorator + + +class Heuristics(KernelInterface): - self.post_hook(full_nargs, exception=None) + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names - return kernel_call + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) -def _batch_benchmark(self, kernel_dict, rep=10, quantiles=None): +def heuristics(values): """ - Benchmark the runtime of the provided function. - By default, return the median runtime of :code:`fn` along with - the 20-th and 80-th performance percentile. - - :param kernel_dict: Function to benchmark - :type kernel_dict: Callable - :param rep: Repetition time (in ms) - :type rep: int - :param quantiles: Performance percentile to return in addition to the median. - :type quantiles: list[float], optional + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. 
code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[list[Any]], Any]] """ - assert len(kernel_dict) > 0, f"ERROR: length of kernel_dict is empty." - kernel_dict_temp_lock = threading.Lock() - tiling_dict_lock = threading.Lock() - tiling_dict = {} - kernel_dict_temp = {} - from triton.compiler.errors import CompileTimeAssertionFailure, CompilationError - from triton.runtime.errors import OutOfResources - from ..compiler.errors import MLIRCompilationError - - def run_fn(config, fn): - try: - with kernel_dict_temp_lock: - fn() - kernel_dict_temp[config] = fn - except (CompileTimeAssertionFailure, MLIRCompilationError, CompilationError) as ex: - with tiling_dict_lock: - tiling_dict[config] = [float('inf')] - raise ex - - def run_all_fns(): - import psutil - max_workers = psutil.cpu_count(logical=False) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for config, fn in kernel_dict.items(): - future = executor.submit(run_fn, config, fn) - futures.append(future) - for future in futures: - try: - future.result() - except Exception as ex: - logging.info(f"Exception raised while benchmarking function.{ex}") - - run_all_fns() - - if self.do_bench.__module__ == "triton.testing": - enable_bench_npu = os.getenv("TRITON_BENCH_METHOD", 'default').lower() == 'npu' - import torch - if torch.npu.is_available() and enable_bench_npu: - from triton.testing import do_bench_multiple_kernel_npu - tiling_dict_temp = do_bench_multiple_kernel_npu(kernel_dict_temp, active=max(30, rep), prof_dir=None, - keep_res=False) - tiling_dict.update(tiling_dict_temp) - return tiling_dict - for config, kernel_call in kernel_dict_temp.items(): - try: - tiling_dict[config] = self.do_bench(kernel_call, quantiles=quantiles) - except (OutOfResources, CompileTimeAssertionFailure, MLIRCompilationError) as ex: - tiling_dict[config] = [float("inf"), float("inf"), float("inf")] - return tiling_dict - - -def _profile(autotuner, *args, config, **meta): - from ..testing import do_bench_npu - kernel_call = _tiling_kernel(*args, config=config, **meta) - do_bench_npu(kernel_call, prof_dir=autotuner.auto_profile_dir, keep_res=True) - - -def _batch_bench(self, *args, configs, **kwargs): - kernel_dict = {config: _tiling_kernel(self, *args, config=config, **kwargs) for config in configs} - return _batch_benchmark(self, kernel_dict=kernel_dict, quantiles=(0.5, 0.2, 0.8)) - - -def ext_Autotuner_batch_bench(autotuner, *args, configs, **kwargs): - return _batch_bench(autotuner, *args, configs=configs, **kwargs) - - -def ext_Autotuner_profile(autotuner, used_cached_result, args, kwargs): - if not used_cached_result and autotuner.auto_profile_dir is not None: - _profile(autotuner, *args, config=autotuner.best_config, **kwargs) - - -def default_Config_arg_is_none(): - return True - - -def set_Config_extra_options(config, bishengir_options): - # BiShengIR Options allowed for autotune - config.multibuffer = bishengir_options.get("multibuffer", None) # Compiler Default True - config.sync_solver = bishengir_options.get("sync_solver", None) # Compiler Default False - config.unit_flag = bishengir_options.get("unit_flag", None) # 
Compiler Default False - config.limit_auto_multi_buffer_only_for_local_buffer = bishengir_options.get( - "limit_auto_multi_buffer_only_for_local_buffer", None) # Compiler Default False - config.limit_auto_multi_buffer_of_local_buffer = bishengir_options.get("limit_auto_multi_buffer_of_local_buffer", - None) # Compiler Default no-limit - config.set_workspace_multibuffer = bishengir_options.get("set_workspace_multibuffer", None) # Compiler Default 1 - config.enable_hivm_auto_cv_balance = bishengir_options.get("enable_hivm_auto_cv_balance", - None) # Compiler Default True - config.tile_mix_vector_loop = bishengir_options.get("tile_mix_vector_loop", None) # Compiler Default 1 - config.tile_mix_cube_loop = bishengir_options.get("tile_mix_cube_loop", None) # Compiler Default 1 - - -def ext_Config_all_kwargs(config): - return ( - ("force_simt_template", config.force_simt_template), - ("enable_linearize", config.enable_linearize), - ("multibuffer", config.multibuffer), - ("enable_hivm_auto_cv_balance", config.enable_hivm_auto_cv_balance), - ("sync_solver", config.sync_solver), - ("unit_flag", config.unit_flag), - ("limit_auto_multi_buffer_only_for_local_buffer", \ - config.limit_auto_multi_buffer_only_for_local_buffer), - ("limit_auto_multi_buffer_of_local_buffer", config.limit_auto_multi_buffer_of_local_buffer), - ("set_workspace_multibuffer", config.set_workspace_multibuffer), - ("tile_mix_vector_loop", config.tile_mix_vector_loop), - ("tile_mix_cube_loop", config.tile_mix_cube_loop) - ) - - -def ext_Config_to_str(res, config): - res.append(f"multibuffer: {config.multibuffer}") - res.append(f"enable_hivm_auto_cv_balance: {config.enable_hivm_auto_cv_balance}") - res.append(f"sync_solver: {config.sync_solver}") - res.append(f"unit_flag: {config.unit_flag}") - res.append(f"limit_auto_multi_buffer_only_for_local_buffer: \ - {config.limit_auto_multi_buffer_only_for_local_buffer}") - res.append(f"limit_auto_multi_buffer_of_local_buffer: {config.limit_auto_multi_buffer_of_local_buffer}") - res.append(f"set_workspace_multibuffer: {config.set_workspace_multibuffer}") - res.append(f"tile_mix_vector_loop: {config.tile_mix_vector_loop}") - res.append(f"tile_mix_cube_loop: {config.tile_mix_cube_loop}") - res.append(f"force_simt_template: {config.force_simt_template}") - - -def new_AutoTilingTuner(hints, fn, configs, key, reset_to_zero, restore_value, pre_hook, post_hook, prune_configs_by, - warmup, rep, use_cuda_graph, do_bench, auto_profile_dir): - if hints is not None and hints.get("enable_ascend_autotune"): - from .autotiling_tuner import AutoTilingTuner - return AutoTilingTuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook, - post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep, - use_cuda_graph=use_cuda_graph, do_bench=do_bench, auto_profile_dir=auto_profile_dir) - return None + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/third_party/ascend/backend/spec/triton/runtime/code_cache.py b/third_party/ascend/backend/spec/triton/runtime/code_cache.py index 66534c505..43d841cd3 100644 --- a/third_party/ascend/backend/spec/triton/runtime/code_cache.py +++ b/third_party/ascend/backend/spec/triton/runtime/code_cache.py @@ -1,16 +1,25 @@ -# Copyright © 2024 BAAI +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# Copyright © 2024 BAAI. All rights reserved. 
# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: # -# http://www.apache.org/licenses/LICENSE-2.0 +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. # Modifications: # - 2025-06-03: diff --git a/third_party/ascend/backend/spec/triton/runtime/interpreter.py b/third_party/ascend/backend/spec/triton/runtime/interpreter.py new file mode 100644 index 000000000..7ad9b1b9f --- /dev/null +++ b/third_party/ascend/backend/spec/triton/runtime/interpreter.py @@ -0,0 +1,1539 @@ +import ast +import textwrap +import inspect +from typing import Tuple + +import math +import numpy as np + +import triton +import triton.language as tl +from dataclasses import dataclass +from .errors import InterpreterError +from functools import partial +from .._C.libtriton import interpreter as _interpreter +from .._C.libtriton import ir as _ir + + +class TensorHandle: + + def __init__(self, data, dtype): + ''' + data: numpy array + dtype: triton type, either pointer_type or scalar_type. 
+ we don't store block_type here because the shape information is already availale in the data field + attr: a dictionary of attributes + ''' + self.data = data + self.dtype = dtype + self.attr = {} + + def __bool__(self): + return bool(self.data.all()) + + def get_element_ty(self): + dtype = self.dtype + while hasattr(dtype, "element_ty"): + dtype = dtype.element_ty + return dtype + + def clone(self): + return TensorHandle(self.data.copy(), self.dtype) + + def set_attr(self, key, value): + self.attr[key] = value + + +class BlockPointerHandle: + + def __init__(self, base, shape, strides, offsets, tensor_shape, order): + self.base = base + self.shape = shape + self.strides = strides + self.offsets = offsets + self.tensor_shape = tensor_shape + self.order = order + + def materialize_pointers(self, boundary_check): + dtype_tt = self.base.get_element_ty() + n_bytes = dtype_tt.primitive_bitwidth // 8 + tensor_shape = self.tensor_shape + ptrs = np.broadcast_to(self.base.data, self.tensor_shape) + masks = np.ones(self.tensor_shape, dtype=bool) + for dim in range(len(tensor_shape)): + bcast_dims = [1] * len(tensor_shape) + bcast_dims[dim] = tensor_shape[dim] + off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) + if dim in boundary_check: + masks = np.logical_and(masks, off < self.shape[dim].data) + ptrs = TensorHandle(ptrs, self.base.dtype.scalar) + return ptrs, masks + + +@dataclass(frozen=True) +class InterpreterOptions: + extern_libs: dict = None + debug: bool = False + sanitize_overflow: bool = True + arch: str = None + supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15") + deprecated_fp8_dtypes: Tuple[str] = () + default_dot_input_precision: str = "tf32" + allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee") + max_num_imprecise_acc_default: int = 0 + backend_name: str = "interpreter" + + +def _get_signed_np_dtype(dtype): + if dtype == np.uint8: + return np.int8 + if dtype == np.uint16: + return np.int16 + if dtype == np.uint32: + return np.int32 + if dtype == np.uint64: + return np.int64 + return dtype + + +def _get_np_dtype(tt_dtype): + if isinstance(tt_dtype, tl.pointer_type): + return np.dtype(np.uint64) + np_types = { + tl.int1: np.dtype(bool), + tl.float16: np.dtype(np.float16), + tl.float32: np.dtype(np.float32), + tl.float64: np.dtype(np.float64), + tl.int8: np.dtype(np.int8), + tl.uint8: np.dtype(np.uint8), + tl.int16: np.dtype(np.int16), + tl.uint16: np.dtype(np.uint16), + tl.int32: np.dtype(np.int32), + tl.uint32: np.dtype(np.uint32), + tl.int64: np.dtype(np.int64), + tl.uint64: np.dtype(np.uint64), + # bfloat16 types are stored as uint16 + tl.bfloat16: np.dtype(np.uint16), + # float8 types are stored as uint8 + tl.float8e5: np.dtype(np.uint8), + tl.float8e5b16: np.dtype(np.uint8), + tl.float8e4nv: np.dtype(np.uint8), + tl.float8e4b8: np.dtype(np.uint8), + tl.float8e4b15: np.dtype(np.uint8), + } + if isinstance(tt_dtype, tl.block_type): + if isinstance(tt_dtype.element_ty, tl.pointer_type): + return np.dtype(np.uint64) + return np_types[tt_dtype.element_ty] + return np_types[tt_dtype] + + +def _convert_float(input, input_dtype, output_dtype, rounding_mode): + input_uint_dtype = getattr(np, f"uint{input_dtype.primitive_bitwidth}") + output_unint_dtype = getattr(np, f"uint{output_dtype.primitive_bitwidth}") + input_bin = np.frombuffer(input.tobytes(), dtype=input_uint_dtype) + sign = (input_bin >> 
(input_dtype.primitive_bitwidth - 1)) & 0x01 + input_exponent_width = input_dtype.primitive_bitwidth - input_dtype.fp_mantissa_width - 1 + output_exponent_width = output_dtype.primitive_bitwidth - output_dtype.fp_mantissa_width - 1 + significand = input_bin & ((1 << input_dtype.fp_mantissa_width) - 1) + bias_input = input_dtype.exponent_bias + bias_output = output_dtype.exponent_bias + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + subnormal_index = exponent == 0 + if np.any(subnormal_index): + # Credit to Phil: phil@openai.com + # subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # convert it to normal repr: ((-1.0)**sign) * (2.0**(1 + m0 - exp_bias)) * (1 + 2^(m1 - m0) + ... + 2^(mn - m0)) + bit_pos = np.zeros_like(input_bin, dtype=np.int32) + # Find the most significant bit of the mantissa in the significand + for i in range(input_dtype.fp_mantissa_width): + bit_index = ((significand >> i) & 0x01) + # pos should be >= 1 + bit_pos[bit_index == 1] = input_dtype.fp_mantissa_width - i + zero_significand_index = significand == 0 + exponent[subnormal_index] = 1 - bit_pos[subnormal_index] + # 0 significand and subnormal should be treated as 0 + exponent[zero_significand_index & subnormal_index] = bias_input - bias_output + significand[subnormal_index] = (significand[subnormal_index] << bit_pos[subnormal_index]) & ( + (1 << input_dtype.fp_mantissa_width) - 1) + # Prevent overflow and underflow + exponent_output = np.maximum(0, np.minimum((exponent - bias_input + bias_output), (1 << output_exponent_width) - 1)) + exponent_output = exponent_output.astype(output_unint_dtype) + sign_output = sign.astype(output_unint_dtype) + if input_dtype.primitive_bitwidth > output_dtype.primitive_bitwidth: # Downcast + significand_output = (significand >> (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + if rounding_mode == _ir.ROUNDING_MODE.RTNE: # Round to nearst even + # find the cut-off bit + cut_off = significand & (1 << (input_dtype.fp_mantissa_width - output_dtype.fp_mantissa_width - 1)) + significand_output = significand_output + (cut_off > 0) + significand_output = significand_output.astype(output_unint_dtype) + else: # Upcast + significand_output = (significand.astype(output_unint_dtype) << + (output_dtype.fp_mantissa_width - input_dtype.fp_mantissa_width)) & ( + (1 << output_dtype.fp_mantissa_width) - 1) + subnormal_index = exponent_output == 0 + if np.any(subnormal_index): # underflow + # normal repr: ((-1.0)**sign) * (2.0**(exp - exp_bias_input)) * (1 + 2^(m0) + 2^(m1) + ... + 2^(mn)) + # where m0, m1, ..., mn are the 1-bit of the mantissa + # shift = (1 - exp_bias_output) - (exp - exp_bias_input) + # convert it to subnormal repr: ((-1.0)**sign) * (2.0**(1 - exp_bias_output)) * (2^(-shift) + 2^(m0 - shift) + 2^(m1 - shift) + ... 
+ 2^(mn - shift)) + exponent = ((input_bin >> input_dtype.fp_mantissa_width) & ((1 << input_exponent_width) - 1)).astype(np.int32) + non_zero_exponent_index = exponent != 0 + # If the original exponent is not zero, we still need to shift the significand and consider the 1.0 part in mantissa + subnormal_index = subnormal_index & non_zero_exponent_index + shift = np.zeros_like(input_bin, dtype=np.int32) + shift[subnormal_index] = (1 - bias_output) - (exponent[subnormal_index] - bias_input) + significand_output[subnormal_index] = (significand_output[subnormal_index] >> shift[subnormal_index]) | ( + 1 << (output_dtype.fp_mantissa_width - shift[subnormal_index])) + output = (sign_output << (output_dtype.primitive_bitwidth - 1)) | ( + exponent_output << output_dtype.fp_mantissa_width) | significand_output + return output.reshape(input.shape) + + +def _erf(x): + # Numpy does not support erf + return math.erf(x) + + +def _umulhi_64(a, b): + # Numpy does not support 128-bit multiplication + # So we have to implement it manually + return (int(a) * int(b)) >> 64 + + +np_erf_fp32 = np.vectorize(_erf, otypes=[np.float32]) +np_erf_fp64 = np.vectorize(_erf, otypes=[np.float64]) +np_umulhi_u64 = np.vectorize(_umulhi_64, otypes=[np.uint64]) + + +class ExtraFunctions: + + @staticmethod + def _convert_custom_types(input, dst_ty, fp_downcast_rounding, _builder): + return tl.tensor(_builder.create_fp_to_fp(input.handle, dst_ty, fp_downcast_rounding), dst_ty) + + +class InterpreterBuilder: + ir_sem_to_interpreter_sem = { + _ir.MEM_SEMANTIC.ACQUIRE: _interpreter.MEM_SEMANTIC.ACQUIRE, + _ir.MEM_SEMANTIC.RELEASE: _interpreter.MEM_SEMANTIC.RELEASE, + _ir.MEM_SEMANTIC.RELAXED: _interpreter.MEM_SEMANTIC.RELAXED, + _ir.MEM_SEMANTIC.ACQUIRE_RELEASE: _interpreter.MEM_SEMANTIC.ACQUIRE_RELEASE, + } + + ir_rmw_op_to_interpreter_rmw_op = { + _ir.ATOMIC_OP.ADD: _interpreter.RMW_OP.ADD, + _ir.ATOMIC_OP.FADD: _interpreter.RMW_OP.FADD, + _ir.ATOMIC_OP.MIN: _interpreter.RMW_OP.MIN, + _ir.ATOMIC_OP.UMIN: _interpreter.RMW_OP.UMIN, + _ir.ATOMIC_OP.MAX: _interpreter.RMW_OP.MAX, + _ir.ATOMIC_OP.UMAX: _interpreter.RMW_OP.UMAX, + _ir.ATOMIC_OP.AND: _interpreter.RMW_OP.AND, + _ir.ATOMIC_OP.OR: _interpreter.RMW_OP.OR, + _ir.ATOMIC_OP.XOR: _interpreter.RMW_OP.XOR, + _ir.ATOMIC_OP.XCHG: _interpreter.RMW_OP.XCHG, + } + + def __init__(self) -> None: + self.arch = None + self.options = InterpreterOptions() + self.codegen_fns = {} + self.codegen_fns["convert_custom_types"] = ExtraFunctions._convert_custom_types + # For interpreter mode, don't enforce GPU hardware shape constraints + # NumPy matmul works with any size, including small matrices + self.codegen_fns["min_dot_size"] = lambda lhsType, rhsType: (1, 1, 1) + # Sub-vector core ID for simulating 1:2 hardware ratio + self.sub_vec_id = 0 + + def set_grid_idx(self, x, y, z): + if not x < self.grid_dim[0]: + raise ValueError("x >= grid_dim[0]") + if not y < self.grid_dim[1]: + raise ValueError("y >= grid_dim[1]") + if not z < self.grid_dim[2]: + raise ValueError("z >= grid_dim[2]") + self.grid_idx = (x, y, z) + + def set_grid_dim(self, nx, ny, nz): + self.grid_dim = (nx, ny, nz) + + # constants + + def get_half_ty(self): + return tl.float16 + + def get_bf16_ty(self): + return tl.bfloat16 + + def get_float_ty(self): + return tl.float32 + + def get_double_ty(self): + return tl.float64 + + def get_int8_ty(self): + return tl.int8 + + def get_uint8_ty(self): + return tl.uint8 + + def get_int16_ty(self): + return tl.int16 + + def get_uint16_ty(self): + return tl.uint16 + + def get_int32_ty(self): + 
return tl.int32 + + def get_uint32_ty(self): + return tl.uint32 + + def get_int64_ty(self): + return tl.int64 + + def get_uint64_ty(self): + return tl.uint64 + + def get_fp8e4nv_ty(self): + return tl.float8e4nv + + def get_fp8e4b15_ty(self): + return tl.float8e4b15 + + def get_fp8e4b8_ty(self): + return tl.float8e4b8 + + def get_fp8e5_ty(self): + return tl.float8e5 + + def get_fp8e5b16_ty(self): + return tl.float8e5b16 + + def get_ptr_ty(self, elt_ty, addr_space): + return tl.pointer_type(elt_ty, addr_space) + + def get_block_ty(self, dtype, shape): + return tl.block_type(dtype, shape) + + def get_int1(self, value): + return TensorHandle(np.array([value], dtype=np.bool_), tl.int1) + + def get_uint8(self, value): + return TensorHandle(np.array([value], dtype=np.uint8), tl.uint8) + + def get_int8(self, value): + return TensorHandle(np.array([value], dtype=np.int8), tl.int8) + + def get_uint16(self, value): + return TensorHandle(np.array([value], dtype=np.uint16), tl.uint16) + + def get_int16(self, value): + return TensorHandle(np.array([value], dtype=np.int16), tl.int16) + + def get_uint32(self, value): + return TensorHandle(np.array([value], dtype=np.uint32), tl.uint32) + + def get_int32(self, value): + return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + + def get_uint64(self, value): + return TensorHandle(np.array([value], dtype=np.uint64), tl.uint64) + + def get_int64(self, value): + return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + + def get_fp16(self, value): + return TensorHandle(np.array([value], dtype=np.float16), tl.float16) + + def get_fp32(self, value): + return TensorHandle(np.array([value], dtype=np.float32), tl.float32) + + def get_fp64(self, value): + return TensorHandle(np.array([value], dtype=np.float64), tl.float64) + + def get_null_value(self, type): + return TensorHandle(np.array([0], dtype=_get_np_dtype(type)), type) + + # programming model + def create_get_program_id(self, axis): + if self.grid_idx is None: + raise ValueError("grid_idx is None") + return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32) + + def create_get_num_programs(self, axis): + return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) + + # memory ops + def create_load(self, ptr, _0, _1, is_volatile): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + other = None + return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) + + def create_store(self, ptr, val, _0, _1): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + return self.create_masked_store(ptr, val, mask, None, None) + + def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if other is None: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) + return TensorHandle(ret, dtype_tt) + + def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): + return _interpreter.store(ptrs.data, value.data, mask.data) + + # casting ops + def cast_impl(self, src, dst_type): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + if (src_element_type == tl.bfloat16 and dst_element_type == tl.float32) or \ + (src_element_type == tl.float32 and dst_element_type == tl.bfloat16): + data = _convert_float(src.data, src_element_type, dst_element_type, None).view(_get_np_dtype(dst_type)) + return 
TensorHandle(data, dst_type.scalar) + else: + return TensorHandle(src.data.astype(_get_np_dtype(dst_type)), dst_type.scalar) + + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type) + + def create_fp_to_fp(self, src, dst_type, rounding_mode): + src_element_type = src.dtype.scalar + dst_element_type = dst_type.scalar + data = _convert_float(src.data, src_element_type, dst_element_type, rounding_mode).view(_get_np_dtype(dst_type)) + return TensorHandle(data, dst_type.scalar) + + def create_bitcast(self, src, dst_type): + return TensorHandle(src.data.view(_get_np_dtype(dst_type)), dst_type.scalar) + + # binary operators + def binary_op(self, lhs, rhs, op): + return TensorHandle(op(lhs.data, rhs.data), lhs.dtype.scalar) + + create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_precise_divf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_sdiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + create_udiv = lambda self, lhs, rhs: self.create_idiv(lhs, rhs) + # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. 
+ create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.fmod) + create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift) + create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minimumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maximumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxnumf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) + create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) + create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + create_int_to_ptr = create_bitcast + create_ptr_to_int = create_bitcast + + def create_idiv(self, lhs, rhs): + # Triton has IEEE, not numpy/torch, semantics for %, and those carry + # through to //, so we have to use a nonstandard expression to get a + # reference result for //. 
+ return TensorHandle((lhs.data - np.fmod(lhs.data, rhs.data)) // rhs.data, lhs.dtype.scalar) + + def create_ashr(self, lhs, rhs): + # Triton's rshift operator depends on the signedness of the left operand + lhs_dtype = _get_signed_np_dtype(lhs.data.dtype) + rhs_dtype = _get_signed_np_dtype(rhs.data.dtype) + lhs.data = lhs.data.astype(lhs_dtype) + rhs.data = rhs.data.astype(rhs_dtype) + return self.binary_op(lhs, rhs, np.right_shift) + + def create_umulhi(self, lhs, rhs): + dtype = lhs.data.dtype + if dtype == np.int64 or dtype == np.uint64: + return TensorHandle(np_umulhi_u64(lhs.data, rhs.data), lhs.dtype.scalar) + else: + compute_dtype = getattr(np, f"uint{dtype.itemsize * 8 * 2}") + lhs_data = lhs.data.astype(compute_dtype) + rhs_data = rhs.data.astype(compute_dtype) + ret_data = np.multiply(lhs_data, rhs_data) >> (dtype.itemsize * 8) + return TensorHandle(ret_data.astype(dtype), lhs.dtype.scalar) + + # ternary functions + def ternary_op(self, lhs, rhs, other, op): + return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype.scalar) + + create_clampf = lambda self, arg, lo, hi, propagate_nans: self.ternary_op(arg, lo, hi, np.clip) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) + + def create_fma(self, x, y, z): + return TensorHandle(x.data * y.data + z.data, z.dtype.scalar) + + # unary functions + def unary_op(self, arg, op): + return TensorHandle(op(arg.data), arg.dtype.scalar) + + def create_fabs(self, arg): + # Mask out the sign bit based on the primitive length + dtype_tt = arg.dtype + mask_bitwidth = dtype_tt.primitive_bitwidth - 1 + np_uint_dtype = getattr(np, f"uint{dtype_tt.primitive_bitwidth}") + data = arg.data.view(np_uint_dtype) + mask = (1 << mask_bitwidth) - 1 + ret = (data & mask).view(_get_np_dtype(dtype_tt)) + return TensorHandle(ret, arg.dtype.scalar) + + create_cos = lambda self, arg: self.unary_op(arg, np.cos) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) + create_exp2 = lambda self, arg: self.unary_op(arg, np.exp2) + create_iabs = lambda self, arg: self.unary_op(arg, np.abs) + create_floor = lambda self, arg: self.unary_op(arg, np.floor) + create_ceil = lambda self, arg: self.unary_op(arg, np.ceil) + create_log = lambda self, arg: self.unary_op(arg, np.log) + create_log2 = lambda self, arg: self.unary_op(arg, np.log2) + create_precise_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_sin = lambda self, arg: self.unary_op(arg, np.sin) + + def create_erf(self, arg): + ret = np_erf_fp32(arg.data) if arg.data.dtype == np.float32 else np_erf_fp64(arg.data) + return TensorHandle(ret, arg.dtype.scalar) + + def create_rsqrt(self, arg): + return TensorHandle(1 / np.sqrt(arg.data), arg.dtype.scalar) + + # tensor operators + create_reshape = lambda self, arg, shape, allow_reorder: TensorHandle(arg.data.reshape(shape), arg.dtype.scalar) + + def create_trans(self, arg, perm): + return TensorHandle(np.transpose(arg.data, perm), arg.dtype.scalar) + + def create_dot(self, a, b, d, input_precision, max_num_imprecise_acc): + a_data = a.data + b_data = b.data + if (a.dtype.primitive_bitwidth == 8 and a.dtype.is_floating()) or \ + (b.dtype.primitive_bitwidth == 8 and b.dtype.is_floating()): + a_data = _convert_float(a_data, a.dtype, tl.float16, None).view(np.float16) + b_data = _convert_float(b_data, b.dtype, tl.float16, None).view(np.float16) + return TensorHandle(np.matmul(a_data, b_data, dtype=d.data.dtype) + d.data, d.dtype.scalar) + + def 
create_make_range(self, start, stop): + return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) + + def create_histogram(self, data, bins): + return TensorHandle(np.histogram(data.data, bins=bins, range=(0, bins))[0], tl.int32) + + # pointer arithmetic + + def create_addptr(self, ptr, offset): + dtype_tt = ptr.get_element_ty() + element_bitwidth = dtype_tt.primitive_bitwidth + # int1's bitwidth is 1, but we need to use 8 for pointer arithmetic + element_bytewidth = max(1, element_bitwidth // 8) + return TensorHandle(ptr.data + element_bytewidth * offset.data.astype(np.uint64), ptr.dtype) + + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, + is_volatile): + ptrs, masks = ptr.materialize_pointers(boundary_check) + dtype_tt = ptrs.get_element_ty() + dtype_np = _get_np_dtype(dtype_tt) + if padding_option is None: + other = None + elif padding_option == _ir.PADDING_OPTION.PAD_ZERO: + other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt) + elif padding_option == _ir.PADDING_OPTION.PAD_NAN: + other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt) + else: + raise ValueError(f"unsupported padding option {padding_option}") + return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile) + + def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + ptrs, masks = ptr.materialize_pointers(boundary_check) + return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) + + def create_expand_dims(self, arg, axis): + return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype.scalar) + + def create_broadcast(self, arg, shape): + return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype.scalar) + + def create_cat(self, lhs, rhs): + return TensorHandle(np.concatenate([lhs.data, rhs.data]), lhs.dtype.scalar) + + def create_join(self, lhs, rhs): + # Triton only supports joining two original tensors into a new one along the last axis + return TensorHandle(np.stack([lhs.data, rhs.data], axis=-1), lhs.dtype.scalar) + + def create_split(self, val): + # Triton only supports splitting the original tensor into two along the last axis + return (TensorHandle(val.data[..., 0], val.dtype.scalar), TensorHandle(val.data[..., 1], val.dtype.scalar)) + + def create_splat(self, arg, shape): + if isinstance(arg.dtype, tl.block_type): + return TensorHandle(np.full(shape, arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + else: # scalar + return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar) + + # Extension ops for Ascend + def create_extract_scalar(self, tensor_handle, indices): + """ + Extract a scalar from a tensor using indices (equivalent to get_element). 
+ + :param tensor_handle: The tensor to extract from + :param indices: List of scalar indices (can be TensorHandle or Python int) + :return: Scalar value + """ + # Convert indices from TensorHandle or Python int to integers + index_values = [] + for idx in indices: + if isinstance(idx, int): + # Python int passed directly (e.g., from loop counter) + index_values.append(idx) + elif isinstance(idx, TensorHandle): + # Interpreter TensorHandle + index_values.append(int(idx.data.item()) if hasattr(idx.data, 'item') else int(idx.data)) + else: + # Fallback: try to extract data + index_values.append( + int(idx.data.item()) if hasattr(idx, 'data') and hasattr(idx.data, 'item') else + int(idx.data) if hasattr(idx, 'data') else int(idx)) + + # Extract the scalar value + scalar_data = tensor_handle.data[tuple(index_values)] + return TensorHandle(np.array([scalar_data]), tensor_handle.dtype.scalar) + + def create_insert_slice(self, full_tensor, sub_tensor, offsets, sizes, strides): + """ + Insert a sub-tensor into a full tensor at specified offsets. + + :param full_tensor: The full tensor (destination) + :param sub_tensor: The sub-tensor to insert + :param offsets: List of offset TensorHandle objects or Python ints + :param sizes: List of size integers + :param strides: List of stride integers + :return: Modified tensor with sub_tensor inserted + """ + result = full_tensor.data.copy() + + # Convert offsets from TensorHandle or Python int to integers + offset_values = [] + for off in offsets: + if isinstance(off, int): + # Python int passed directly + offset_values.append(off) + elif isinstance(off, TensorHandle): + # Interpreter TensorHandle + offset_values.append(int(off.data.item()) if hasattr(off.data, 'item') else int(off.data)) + else: + # Fallback + offset_values.append( + int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') else + int(off.data) if hasattr(off, 'data') else int(off)) + + # Build slices for insertion + slices = [] + for i, (offset, size, stride) in enumerate(zip(offset_values, sizes, strides)): + end = offset + size * stride + if stride == 1: + slices.append(slice(offset, end)) + else: + slices.append(slice(offset, end, stride)) + + # Insert the sub-tensor + result[tuple(slices)] = sub_tensor.data + + return TensorHandle(result, full_tensor.dtype.scalar) + + def create_extract_slice(self, full_tensor, offsets, sizes, strides): + """ + Extract a slice from a full tensor. 
+ + :param full_tensor: The full tensor + :param offsets: List of offset TensorHandle objects or Python ints + :param sizes: List of size integers + :param strides: List of stride integers + :return: Extracted sub-tensor + """ + # Convert offsets from TensorHandle or Python int to integers + offset_values = [] + for off in offsets: + if isinstance(off, int): + # Python int passed directly + offset_values.append(off) + elif isinstance(off, TensorHandle): + # Interpreter TensorHandle + offset_values.append(int(off.data.item()) if hasattr(off.data, 'item') else int(off.data)) + else: + # Fallback + offset_values.append( + int(off.data.item()) if hasattr(off, 'data') and hasattr(off.data, 'item') else + int(off.data) if hasattr(off, 'data') else int(off)) + + # Build slices for extraction + slices = [] + for i, (offset, size, stride) in enumerate(zip(offset_values, sizes, strides)): + end = offset + size * stride + if stride == 1: + slices.append(slice(offset, end)) + else: + slices.append(slice(offset, end, stride)) + + # Extract the slice + extracted = full_tensor.data[tuple(slices)] + + return TensorHandle(extracted, full_tensor.dtype.scalar) + + def create_index_select_simd(self, src_ptr, index_tensor, dim, src_shape, src_offset, read_shape, result_shape): + """ + SIMD index_select operation (gather with indices along a dimension). + + :param src_ptr: Source tensor pointer + :param index_tensor: 1D tensor of indices + :param dim: Dimension to select from + :param src_shape: List of source shape (int or TensorHandle) + :param src_offset: List of source offset (int or TensorHandle) + :param read_shape: List of read shape (int or TensorHandle) + :param result_shape: List of result shape (int or TensorHandle) + :return: Result tensor with selected indices + """ + + # Convert src_shape, src_offset, read_shape to integers + def to_int(val): + if isinstance(val, TensorHandle): + return int(val.data.item()) + return int(val) + + src_shape_vals = [to_int(s) for s in src_shape] + src_offset_vals = [to_int(o) if o != -1 else -1 for o in src_offset] + read_shape_vals = [to_int(r) if r != -1 else -1 for r in read_shape] + result_shape_vals = [to_int(r) for r in result_shape] + + # Get index values - handle both array and TensorHandle + if isinstance(index_tensor, TensorHandle): + indices = index_tensor.data.flatten() + else: + indices = np.asarray(index_tensor).flatten() + + # Ensure indices are integers + if indices.dtype not in [np.int32, np.int64]: + indices = indices.astype(np.int32) + + # Create result tensor + result = np.empty(result_shape_vals, dtype=src_ptr.data.dtype) + + # Perform index_select: for each index, read the specified data + for out_idx, in_idx in enumerate(indices): + in_idx = int(in_idx) + + # Validate index bounds + if not (0 <= in_idx < src_shape_vals[dim]): + # Out of bounds - fill with zeros + result_slices = [slice(None)] * len(result_shape_vals) + result_slices[dim] = slice(out_idx, out_idx + 1) + result[tuple(result_slices)] = 0 + continue + + # Build source slice + src_slices = [] + for d in range(len(src_shape_vals)): + if d == dim: + src_slices.append(slice(in_idx, in_idx + 1)) + else: + offset = src_offset_vals[d] if src_offset_vals[d] != -1 else 0 + read_size = read_shape_vals[d] if read_shape_vals[d] != -1 else src_shape_vals[d] + # Clamp to valid range + offset = max(0, min(offset, src_shape_vals[d] - 1)) + read_size = min(read_size, src_shape_vals[d] - offset) + src_slices.append(slice(offset, offset + read_size)) + + # Build result slice + result_slices = [] + 
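+            # Along `dim` the destination is the single position `out_idx`; every
+            # other dimension of the result receives the full gathered extent.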
for d in range(len(result_shape_vals)): + if d == dim: + result_slices.append(slice(out_idx, out_idx + 1)) + else: + result_slices.append(slice(None)) + + # Copy data with proper shape handling + try: + src_data = src_ptr.data[tuple(src_slices)] + # Handle shape mismatch by resizing + target_shape = [result_shape_vals[d] if d != dim else 1 for d in range(len(result_shape_vals))] + if src_data.shape != tuple(target_shape): + # Pad or trim as needed + pad_width = [(0, target_shape[d] - src_data.shape[d]) for d in range(len(target_shape))] + src_data = np.pad(src_data, pad_width, mode='constant', constant_values=0) + result[tuple(result_slices)] = src_data + except Exception as e: + # On error, fill with zeros + result[tuple(result_slices)] = 0 + + return TensorHandle(result, src_ptr.dtype.scalar) + + def create_get_sub_vec_id(self): + """ + Get the Vector Core index on the AI Core. + + In Interpreter mode, simulate multiple vector cores by maintaining + a sub_vec_id counter. This is used for 1:2 hardware ratio emulation + where different vector cores process different partitions of the data. + + :return: Vector Core ID as TensorHandle (int64, scalar) + """ + # Return the current sub_vec_id (set by GridExecutor) + vec_id = np.int64(self.sub_vec_id) + return TensorHandle(np.array([vec_id], dtype=np.int64), tl.int64) + + def sync_block_set(self, sender, receiver, event_id, sender_pipe_value, receiver_pipe_value): + """ + Set synchronization event between compute and vector units. + + In Interpreter mode, this is a no-op since we execute single-threaded. + Synchronization is not needed in CPU emulation. + + :param sender: Source unit ("cube" or "vector") + :param receiver: Destination unit ("cube" or "vector") + :param event_id: Event ID (TensorHandle) + :param sender_pipe_value: Sender pipe value + :param receiver_pipe_value: Receiver pipe value + """ + # No-op in interpreter mode: single-threaded execution doesn't need sync + pass + + def sync_block_wait(self, sender, receiver, event_id, sender_pipe_value, receiver_pipe_value): + """ + Wait for synchronization event between compute and vector units. + + In Interpreter mode, this is a no-op since we execute single-threaded. + Synchronization is not needed in CPU emulation. + + :param sender: Source unit ("cube" or "vector") + :param receiver: Destination unit ("cube" or "vector") + :param event_id: Event ID (TensorHandle) + :param sender_pipe_value: Sender pipe value + :param receiver_pipe_value: Receiver pipe value + """ + # No-op in interpreter mode: single-threaded execution doesn't need sync + pass + + def sync_block_all(self, mode, event_id): + """ + Synchronize all compute or vector units globally. + + In Interpreter mode, this is a no-op since we execute single-threaded. + Synchronization is not needed in CPU emulation. 
+ + :param mode: Sync mode ("all_cube", "all_vector", "all", "all_sub_vector") + :param event_id: Event ID (int, constexpr, or TensorHandle) + """ + # No-op in interpreter mode: single-threaded execution doesn't need sync + pass + + def create_atomic_cas(self, ptr, cmp, val, sem, scope): + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_cas(ptr.data, cmp.data, val.data, sem), cmp.dtype.scalar) + + def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem, scope): + if rmwOp not in self.ir_rmw_op_to_interpreter_rmw_op: + raise ValueError(f"unsupported rmwOp {rmwOp}") + if sem not in self.ir_sem_to_interpreter_sem: + raise ValueError(f"unsupported semantic {sem}") + rmwOp = self.ir_rmw_op_to_interpreter_rmw_op[rmwOp] + sem = self.ir_sem_to_interpreter_sem[sem] + return TensorHandle(_interpreter.atomic_rmw(rmwOp, ptr.data, val.data, mask.data, sem), val.dtype.scalar) + + def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + raise NotImplementedError("extern_elementwise not supported in interpreter mode") + + def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): + raise NotImplementedError("inline_asm not supported in interpreter mode") + + def create_print(self, prefix, hex, values, isSigned): + # NOTE: the `isSigned` variable is not really used here; because Signness is already known + # by `values` themselves in python interpreter, thus not really needed here; + # it is only used for triton PrintOpToLLVM to correctly construct the format specifier. + # Interpreter's device_print function has a different format than Triton's device_print + msg = f"({self.grid_idx[0]}, {self.grid_idx[1]}, {self.grid_idx[2]})" + if prefix: + msg += f" {prefix}" + if hex: + np.set_printoptions(formatter={'all': lambda x: f"0x{x:02x}"}) + for value in values: + print(msg + f" {value.data}") + if hex: + np.set_printoptions(formatter=None) + + def create_assert(self, condition, message): + # Interpreter's device_assert function has a different format than Triton's device_assert + assert condition, f"{message}" + + def create_assume(self, condition): + assert condition, "Assume failed" + + def create_barrier(self): + # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter + pass + + def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order): + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in offsets] + return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order) + + def create_advance(self, ptr, offsets): + if len(ptr.offsets) != len(offsets): + raise ValueError("len(ptr.offsets) != len(offsets)") + # Create new offsets to avoid modifying the original + new_offsets = [offset.clone() for offset in ptr.offsets] + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order) + for i in range(len(offsets)): + ret.offsets[i].data += offsets[i].data + return ret + + def get_all_ones_value(self, type): + np_type = _get_np_dtype(type) + if "int" in np_type.name: + return TensorHandle(np.full(1, -1, dtype=np_type), type.scalar) + else: + raise TypeError(f"unsupported type {type}") + + +def _patch_attr(obj, name, member, builder): + new_member = lambda *args, member=member, **kwargs: (member(*args, ** + {k: v + for k, v in kwargs.items() + if k != "_builder"}, 
_builder=builder)) + setattr(obj, name, new_member) + + +def _patch_builtin(pkg, builder): + for name, member in inspect.getmembers(pkg): + if tl.core.is_builtin(member): + _patch_attr(pkg, name, member, builder) + + +def _patch_lang_tensor(tensor): + + def _get_bool(self): + data = self.handle.data + # in triton, only scalars can be converted to booleans + # here we need this hack because all scalars are tensors + return bool(data) if data.size == 1 else True + + def _get_transpose(self): + return tl.core.tensor(TensorHandle(np.transpose(self.handle.data), self.handle.dtype), self.dtype.scalar) + + tensor.__index__ = lambda self: int(self.handle.data) + tensor.__bool__ = lambda self: _get_bool(self) + tensor.__repr__ = lambda self: repr(self.handle.data) + tensor.__str__ = lambda self: str(self.handle.data) + tensor.T = property(_get_transpose) + + +class ReduceScanOpIneterface: + + def __init__(self, axis, combine_fn): + self.axis = axis + self.combine_fn = combine_fn + + def check_axis(self, shape, axis): + if axis is not None and axis >= len(shape): + raise ValueError(f"axis {axis} out of bounds for shape {shape}") + + def check_tensor(self, input): + for arg in input: + if not isinstance(arg, tl.core.tensor): + raise ValueError(f"input must be a tensor, got {type(arg)}") + self.check_axis(arg.shape, self.axis) + + def to_tensor(self, ret, dtype): + if hasattr(ret, "shape") and ret.shape: + ret_type = tl.block_type(dtype, ret.shape) + else: + ret = np.array([ret]).astype(_get_np_dtype(dtype)) + ret_type = dtype + return tl.core.tensor(TensorHandle(ret, dtype.scalar), ret_type) + + def apply(self, input): + if not isinstance(input, tuple): + input = (input, ) + self.check_tensor(input) + return self.apply_impl(input) + + def apply_impl(self, input): + raise NotImplementedError("apply_impl not implemented") + + +class ReduceOps(ReduceScanOpIneterface): + + def __init__(self, axis, combine_fn, keep_dims): + super().__init__(axis, combine_fn) + self.keep_dims = keep_dims + + def unravel(self, input, axis): + ret = [] + for data in input: + if axis is not None: + ret.append(data) + else: + axis = 0 + ret.append(self.to_tensor(data.handle.data.flatten(), data.dtype)) + return tuple(ret), axis + + def generic_reduce(self, input): + original_axis = self.axis + input, axis = self.unravel(input, self.axis) + input_data = [] + output_data = [] + input_shape = input[0].handle.data.shape + output_shape = input_shape[0:axis] + input_shape[axis + 1:] + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(output_shape, dtype=arg.handle.data.dtype)) + # Reduce on axis + for i in range(input_data[0].size): + # Recover input_index from i using input_shape + input_index = np.unravel_index(i, input_shape) + output_index = input_index[0:axis] + input_index[axis + 1:] + input_tuple = tuple(self.to_tensor(d[input_index], input[ii].dtype) for ii, d in enumerate(input_data)) + if input_index[axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][output_index] = input_tuple[j].handle.data.item() + else: + acc_tuple = tuple(self.to_tensor(o[output_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *input_tuple) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][output_index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + 
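+        # keep_dims re-inserts the reduced axis (or every axis when the whole
+        # tensor was reduced); otherwise a full reduction collapses to a Python
+        # scalar before being wrapped back into a tensor.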
for i, data in enumerate(output_data): + if self.keep_dims: + if original_axis is not None: + data = np.expand_dims(data, axis) + else: + for _ in range(len(input_shape)): + data = np.expand_dims(data, 0) + + elif original_axis is None: + # Take a scalar + data = data.item() + ret.append(self.to_tensor(data, input[i].dtype)) + return ret[0] if len(ret) == 1 else tuple(ret) + + def min_max(self, input, val_reduce_op, idx_reduce_op=None): + # If input is a tuple, it must be (val, index), and we only take val + input = input[0] if isinstance(input, tuple) else input + val = None + idx = None + if val_reduce_op: + val = self.to_tensor(val_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + if idx_reduce_op: + idx = self.to_tensor(idx_reduce_op(input.handle.data, axis=self.axis, keepdims=self.keep_dims), tl.int32) + if val is not None and idx is not None: + return val, idx + elif val is not None: + return val + elif idx is not None: + return idx + else: + raise ValueError("val_reduce_op and idx_reduce_op are both None") + + def sum(self, input): + return self.to_tensor(np.sum(input.handle.data, axis=self.axis, keepdims=self.keep_dims), input.dtype) + + def apply_impl(self, input): + if self.combine_fn == tl.standard._argmin_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=np.argmin) + elif self.combine_fn == tl.standard._argmax_combine_tie_break_left: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax) + elif self.combine_fn == tl.standard._elementwise_max: + return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None) + elif self.combine_fn == tl.standard._elementwise_min: + return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None) + elif self.combine_fn == tl.standard._sum_combine: + return self.sum(input[0]) + else: + # Fall back to the slow mode + return self.generic_reduce(input) + + +class ScanOps(ReduceScanOpIneterface): + + def __init__(self, axis, combine_fn, reverse): + super().__init__(axis, combine_fn) + self.reverse = reverse + + def cumsum(self, input): + return [self.to_tensor(np.cumsum(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def cumprod(self, input): + return [self.to_tensor(np.cumprod(input.handle.data, axis=self.axis), dtype=input.dtype)] + + def generic_scan(self, input): + input_data = [] + output_data = [] + shape = input[0].handle.data.shape + for arg in input: + input_data.append(arg.handle.data) + output_data.append(np.zeros(shape, dtype=arg.handle.data.dtype)) + # Scan on axis + for i in range(input_data[0].size): + # Recover index from i using shape + index = np.unravel_index(i, shape) + data = tuple(self.to_tensor(d[index], input[ii].dtype) for ii, d in enumerate(input_data)) + if index[self.axis] == 0: + # First element + for j in range(len(output_data)): + output_data[j][index] = data[j].handle.data.item() + else: + prev_index = tuple(index[i] - 1 if i == self.axis else index[i] for i in range(len(index))) + acc_tuple = tuple(self.to_tensor(o[prev_index], input[oi].dtype) for oi, o in enumerate(output_data)) + combine_fn_ret = self.combine_fn.fn(*acc_tuple, *data) + acc_tuple = (combine_fn_ret, ) if not isinstance(combine_fn_ret, tuple) else combine_fn_ret + for j in range(len(output_data)): + output_data[j][index] = acc_tuple[j].handle.data.item() if isinstance( + acc_tuple[j], tl.core.tensor) else acc_tuple[j] + # Pack output + ret = [] + for i, data in enumerate(output_data): + ret.append(self.to_tensor(data, 
input[i].dtype)) + return ret + + def apply_impl(self, input): + new_input = [] + if self.reverse: + for arg in input: + new_input.append(self.to_tensor(np.flip(arg.handle.data, axis=self.axis), arg.dtype)) + else: + new_input = input + if self.combine_fn == tl.standard._sum_combine: + ret = self.cumsum(new_input[0]) + elif self.combine_fn == tl.standard._prod_combine: + ret = self.cumprod(new_input[0]) + else: + # Fall back to the slow mode + ret = self.generic_scan(new_input) + if self.reverse: + for arg in ret: + arg.handle.data = np.flip(arg.handle.data, axis=self.axis) + return len(ret) == 1 and ret[0] or tuple(ret) + + +def _patch_reduce_scan(): + # Because interpreter doesn't support region_builder_fn, we cannot patch the builder + # to use the new reduce and scan functions. + # Instead, we need to patch reduce and reduce functions in tl and tl.core + def _new_reduce(input, axis, combine_fn, keep_dims=False, **kwargs): + return ReduceOps(axis, combine_fn, keep_dims).apply(input) + + def _new_scan(input, axis, combine_fn, reverse=False, **kwargs): + return ScanOps(axis, combine_fn, reverse).apply(input) + + tl.reduce = _new_reduce + tl.associative_scan = _new_scan + tl.core.reduce = _new_reduce + tl.core.associative_scan = _new_scan + + +def _patch_lang_core(lang): + + def _new_to_ir(self, builder): + # We need to specify signedness for integer types in the numpy mode + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name == 'int8': + return builder.get_int8_ty() + elif self.name == 'uint8': + return builder.get_uint8_ty() + elif self.name == 'int16': + return builder.get_int16_ty() + elif self.name == 'uint16': + return builder.get_uint16_ty() + elif self.name == 'int32': + return builder.get_int32_ty() + elif self.name == 'uint32': + return builder.get_uint32_ty() + elif self.name == 'int64': + return builder.get_int64_ty() + elif self.name == 'uint64': + return builder.get_uint64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e4nv': + return builder.get_fp8e4nv_ty() + elif self.name == 'fp8e4b15': + return builder.get_fp8e4b15_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + # can't just map lang.static_range to `range`, because `tl.static_range` + # can get `step` passed by keyword + def _new_range(arg1, arg2=None, step=None, **kwargs): + if step is None: + step = 1 + if arg2 is None: + start, end = 0, arg1 + else: + start, end = arg1, arg2 + return range(start, end, step) + + def _new_static_assert(cond, msg=""): + assert cond, msg + + def _set_attr(input, values, name): + # skip non tensor types. This may happen for induction variables. 
+ if not isinstance(input, tl.tensor): + return input + # Unwrap constexpr + values = [values] if not isinstance(values, (list, tuple)) else values + values = [v.value if isinstance(v, tl.constexpr) else v for v in values] + if len(values) != max(1, len(input.shape)): + raise ValueError(f"len(values) != len(input.shape) for {name}") + input.handle.set_attr(name, values) + return input + + lang.range = _new_range + lang.static_range = _new_range + lang.static_assert = _new_static_assert + lang.static_print = print + lang.dtype.to_ir = _new_to_ir + lang.multiple_of = partial(_set_attr, name="tt.divisiblity") + lang.max_contiguous = partial(_set_attr, name="tt.contiguity") + lang.max_constancy = partial(_set_attr, name="tt.constancy") + + _patch_reduce_scan() + + +def _patch_lang(fn): + langs = [value for _, value in fn.__globals__.items() if value in [tl, tl.core]] + assert len(langs) >= 1, "triton.language must be visible from within jit'd function" + for lang in langs: + _patch_builtin(lang, interpreter_builder) + _patch_builtin(lang.tensor, interpreter_builder) + if lang == tl: + _patch_builtin(lang.math, interpreter_builder) + _patch_lang_tensor(lang.tensor) + _patch_lang_core(lang) + + # Patch all modules in fn's globals that might be extension modules + for name, value in list(fn.__globals__.items()): + if value is None: + continue + try: + # Check if it looks like an extension module (has builtin functions) + if hasattr(value, '__name__') and 'extension' in str(value.__name__): + _patch_builtin(value, interpreter_builder) + # Also try patching any module-like object that might have builtin functions + elif hasattr(value, '__dict__') and not isinstance(value, type): + # Try to patch it and ignore if it fails + try: + _patch_builtin(value, interpreter_builder) + except Exception: + pass + except Exception: + pass + + # Also try importing extension directly as fallback + try: + import triton.language.extra.cann.extension as extension + _patch_builtin(extension, interpreter_builder) + except (ImportError, AttributeError): + # Extension module not available (e.g., non-Ascend backend) + pass + + +# TODO: wrap everything in triton tensors +def _implicit_cvt(arg): + if isinstance(arg, int): + ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + dtype = np.int32 + if -2**31 <= arg < 2**31: + dtype = np.int32 + elif 2**31 <= arg < 2**32: + dtype = np.uint32 + elif -2**63 <= arg < 2**63: + dtype = np.int64 + elif 2**63 <= arg < 2**64: + dtype = np.uint64 + else: + raise ValueError(f"Unsupported integer value {arg}") + handle = TensorHandle(np.array([arg], dtype=dtype), ty) + return tl.tensor(handle, ty) + if hasattr(arg, "data_ptr"): + ty = tl.str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) + return arg + + +interpreter_builder = InterpreterBuilder() + +# These keywords are not supported by the interpreter +RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg", "multibuffer"] + + +class GridExecutor: + + def __init__(self, fn, arg_names, grid): + from .jit import _normalize_ty # TODO: modularize + + self.fn = fn + self.arg_names = arg_names + self.grid = grid + __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} + self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"] + + def 
_init_args_hst(self, args_dev, kwargs): + args_hst = [] + for arg in args_dev: + if hasattr(arg, "data_ptr"): + args_hst.append(arg.cpu()) + else: + args_hst.append(arg) + # Process keyword arguments + kwargs_hst = {} + for key, value in kwargs.items(): + if hasattr(value, "data_ptr"): + kwargs_hst[key] = value.cpu() + else: + kwargs_hst[key] = value + return args_hst, kwargs_hst + + def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst): + for arg_dev, arg_hst in zip(args_dev, args_hst): + if hasattr(arg_dev, "data_ptr"): + arg_dev.data.copy_(arg_hst.to(arg_dev.device).data) + + # Restore keyword arguments + for key, kwarg_dev in kwargs.items(): + kwarg_hst = kwargs_hst[key] + if hasattr(kwarg_dev, "data_ptr"): + kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data) + + def __call__(self, *args_dev, **kwargs): + # removes reserved keywords from kwargs + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} + if kwargs.pop("warmup", False): + return + # copy arguments to the host + args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs) + # remaps core language functions to interpreted ones + _patch_lang(self.fn) + # we need to copy arguments to the host for the interpreter + # implicitly convert tensor arguments to their base pointers + args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst) + args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} + # iterate through grid + grid = self.grid(args) if callable(self.grid) else self.grid + assert len(grid) <= 3, "grid must have at most 3 dimensions" + grid = grid + (1, ) * (3 - len(grid)) + interpreter_builder.set_grid_dim(*grid) + + # Infer the number of sub-vector cores from kernel parameters + # Check for M and sub_M parameters (common pattern for 1:2 ratio) + num_sub_vec_ids = 1 + if 'M' in args and 'sub_M' in args: + M = args['M'] + sub_M = args['sub_M'] + # Extract scalar values if they're TensorHandle + if isinstance(M, TensorHandle): + M = int(M.data.item() if hasattr(M.data, 'item') else M.data) + if isinstance(sub_M, TensorHandle): + sub_M = int(sub_M.data.item() if hasattr(sub_M.data, 'item') else sub_M.data) + # Number of vector cores = M / sub_M + if isinstance(M, int) and isinstance(sub_M, int) and sub_M > 0: + num_sub_vec_ids = max(1, M // sub_M) + + try: + # Loop over sub-vector IDs to simulate parallel vector core execution + for sub_vec_id in range(num_sub_vec_ids): + interpreter_builder.sub_vec_id = sub_vec_id + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + interpreter_builder.set_grid_idx(x, y, z) + self.fn(**args) + except Exception as e: + raise InterpreterError(repr(e)) from e + # copy arguments back to propagate side-effects + self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst) + + +class ASTTransformer(ast.NodeTransformer): + + def visit_Assign(self, node): + names = [] + for target in node.targets: + names += [self.visit(target)] + if len(names) > 1: + raise ValueError("Multiple assignments are not supported") + # Modify the assignment x = value to + # triton.language.semantic.to_tensor(value, interpreter_builder, False) + node.value = ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Attribute(value=ast.Name(id='triton', ctx=ast.Load()), attr='language', ctx=ast.Load()), + attr='semantic', ctx=ast.Load()), attr='to_tensor', ctx=ast.Load()), + args=[node.value, ast.Name(id='interpreter_builder', ctx=ast.Load()), + ast.Constant(value=False)], keywords=[]) + return node + 
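+# Example of the rewrite performed by ASTTransformer.visit_Assign above: a kernel
+# statement such as
+#     acc = x + 1
+# is executed by the interpreter as
+#     acc = triton.language.semantic.to_tensor(x + 1, interpreter_builder, False)
+# so that every assigned value is wrapped as a triton tensor before subsequent
+# interpreted operations consume it.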
+ +class FunctionRewriter: + ast_transformer = ASTTransformer() + + def __init__(self, fn, **kwargs): + self.fn = fn + self.kwargs = kwargs + self.filename: str = "" + # Absolute line number in the file + self.def_file_lineno: int = 0 + + def rewrite_ast(self): + # If exception is raise, it means the function does not have source code available, + # e.g., dynamically generated functions, we cannot rewrite it so just return the original function + try: + lines, _ = inspect.getsourcelines(self.fn) + except Exception: + return self.fn + + # truncate lines before def + # @triton.autotune(...) + # ... + # @triton.jit + # ... + # def foo(...): <- this line is the function definition + self.filename, self.def_file_lineno = self._get_jit_fn_file_line() + self.def_lineno = self._find_def(lines) + src = self._prepare_source(lines) + transformed_ast = self._transform_ast(src) + return self._compile_and_exec(transformed_ast) + + def _get_jit_fn_file_line(self): + from .jit import get_jit_fn_file_line, JITFunction + return get_jit_fn_file_line(JITFunction(self.fn)) + + def _find_def(self, lines): + def_lineno = 0 + # Line numbers start from 1 + for i, line in enumerate(lines): + if line.strip().startswith("def "): + def_lineno = i + 1 + return def_lineno + + def _prepare_source(self, lines): + lines = lines[self.def_lineno - 1:] + src = ''.join(lines) + return textwrap.dedent(src) + + def _transform_ast(self, src): + # src is like: + # 1: def foo(...): + # 2: ... + parsed_ast = ast.parse(src) + transformed_ast = self.ast_transformer.visit(parsed_ast) + ast.fix_missing_locations(transformed_ast) + inc_lineno = self.def_file_lineno - 1 + ast.increment_lineno(transformed_ast, inc_lineno) + return transformed_ast + + def _compile_and_exec(self, transformed_ast): + compiled_code = compile(transformed_ast, filename=self.filename, mode='exec') + local_namespace = {**self.kwargs} + fn_globals = self.fn.__globals__ + for key, value in globals().items(): + if key not in fn_globals: + fn_globals[key] = value + exec(compiled_code, fn_globals, local_namespace) + return local_namespace[self.fn.__name__] + + +class InterpretedFunction: + # Cache all rewritten functions + rewritten_fn = {} + + def __init__(self, fn, **kwargs) -> None: + self.fn = fn + self.rewriter = FunctionRewriter(fn, **kwargs) + + def run(*args, **kwargs): + grid = kwargs["grid"] + fn = self.rewrite() + return GridExecutor(fn, self.arg_names, grid)(*args, **kwargs) + + self.run = run + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + def rewrite(self): + if self.fn not in self.rewritten_fn: + self.rewritten_fn[self.fn] = self.rewriter.rewrite_ast() + return self.rewritten_fn[self.fn] + + @property + def __name__(self): + return self.fn.__name__ + + def __getitem__(self, grid): + fn = self.rewrite() + return GridExecutor(fn, self.arg_names, grid) + + def __call__(self, *args, **kwargs): + # This is a device function call + _patch_lang(self.fn) + fn = self.rewrite() + try: + return fn(*args, **kwargs) + except Exception as e: + raise InterpreterError(repr(e)) from e diff --git a/third_party/ascend/backend/spec/triton/runtime/jit.py b/third_party/ascend/backend/spec/triton/runtime/jit.py index 7c6f17457..45178a40b 100644 --- a/third_party/ascend/backend/spec/triton/runtime/jit.py +++ b/third_party/ascend/backend/spec/triton/runtime/jit.py @@ -1,66 +1,951 @@ -def enable_stream_in_kwargs(kwargs): - return True - - -def ignore_params_in_JITFunction_run(kwargs, excess_kwargs): - ignor_params = 
["debug", "sanitize_overflow", "llvm_version", "kernel_name", \ - "allowed_dot_input_precisions", "multibuffer", "stream", "inject_barrier_all", \ - "inject_block_all", "limit_auto_multi_buffer_only_for_local_buffer"] - not_work_params = [] - for k in kwargs: - if k in ignor_params: - continue - elif k in excess_kwargs: - not_work_params.append(k) - if len(not_work_params) != 0: - print("[WARNING] Please DO NOT tune args {}!".format(not_work_params)) - - -def check_grid_size(grid_0, grid_1, grid_2): - import os - grid_all_size = grid_0 * grid_1 * grid_2 - if os.getenv("TRITON_ALL_BLOCKS_PARALLEL", "0") == "0": - if grid_all_size > 65535: +from __future__ import annotations, division +import ast +import hashlib +import inspect +import itertools +import os +import re +import textwrap +from collections import defaultdict +from functools import cached_property +from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, overload, Dict, Any, Tuple +from ..runtime.driver import driver +from types import ModuleType + +TRITON_MODULE = __name__[:-len(".runtime.jit")] + +T = TypeVar("T") + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. + + This visitor also keeps track of the global variables touched by the + JITFunction. When we launch the kernel, we check that these have the same + values as they did when we ran this visitor. If not, we raise an error (or + otherwise we could recompile). + """ + + def __init__(self, name, globals, src) -> None: + super().__init__() + self.name = name + self.hasher = hashlib.sha256(src.encode("utf-8")) + + # This function's __globals__ dict. + self.globals = globals + + # Python builtins that can be accessed from Triton kernels. + self.supported_python_builtins = { + 'float', + 'getattr', + 'int', + 'isinstance', + 'len', + 'list', + 'max', + 'min', + 'print', + 'range', + } + + # used_global_vals tells us which global variables are used by this + # function and all those it transitively calls, plus the values of those + # variables when each function was initially run. (That is, if A calls + # C, and B calls C, then the values for C in used_global_vals will be + # from the first time C was run, either by A or B.) + # + # Each function may have a different __globals__ dict, so the global + # variable `foo` may actually have a different value in the different + # functions. Thus this map is actually + # (var_name, id(__globals__)) -> (var_value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + self.visiting_arg_default_value = False + + @property + def ret(self): + return self.hasher.hexdigest() + + def _is_triton_builtin(self, node, func): + if inspect.isbuiltin(node.func): + return True + module = getattr(func, "__module__", "") + return module.startswith(TRITON_MODULE) + + def _update_hash(self, func): + if isinstance(func, JITFunction): + # Merge our used_global_vals with those of the called function, + # after checking that all overlapping values are consistent. 
+ for k in self.used_global_vals.keys() & func.used_global_vals.keys(): + var_name, _ = k + v1, _ = self.used_global_vals[k] + v2, _ = func.used_global_vals[k] + if v1 != v2: + raise RuntimeError( + f"Global variable {var_name} has value {v1} when compiling {self.name}, but inner kernel {func.__name__} has conflicting value {v2} from when it was first compiled. This is not allowed." + ) + self.used_global_vals.update(func.used_global_vals) + # update hash + func_key = func.cache_key + func_key += str(getattr(func, "noinline", False)) + self.hasher.update(func_key.encode("utf-8")) + + def visit_Name(self, node): + if type(node.ctx) is ast.Store: + return node.id + + if node.id in self.local_names: + # The global name is hidden by the local name. + return None + + val = self.globals.get(node.id, None) + + # Only keep track of "interesting" global variables, that non-evil users + # might change. Don't consider functions, modules, builtins, etc. This + # helps keep the list of vars we have to check small. + if (val is not None # + # Python default arguments are resolved only once, when the + # function is defined. So if you do `foo(a=A)` and the value of + # A changes, foo will still use the old value of A. + and not self.visiting_arg_default_value + # It would be pretty evil if someone did `import x` and then + # `x = blah`. + and type(val) is not ModuleType + # It would be pretty evil if we used function `foo` inside of + # `bar` and then someone did `foo = baz`. + and not isinstance(val, JITFunction) and not getattr(val, "__triton_builtin__", False) # + and node.id not in self.supported_python_builtins): + self.used_global_vals[(node.id, id(self.globals))] = (val, self.globals) + + self._update_hash(val) + return val + + def visit_Tuple(self, node): + # We need to explicitly return the tuple values so that visit_Assign can + # access them in the case of `a, b = ...`. + return [self.visit(elt) for elt in node.elts] + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or (getattr(lhs, "__name__", "") == TRITON_MODULE): + return None + ret = getattr(lhs, node.attr) + self._update_hash(ret) + return ret + + def visit_FunctionDef(self, node): + # Save the local name, which may hide the global name. + self.local_names = {arg.arg for arg in node.args.args} + self.generic_visit(node) + + def visit_arguments(self, node): + # The purpose of this function is to visit everything in `arguments` + # just like `generic_visit`, except when we're visiting default values + # (i.e. the `foo` part of `def fn(x = foo)`), we set + # self.visiting_arg_default_value = True. This allows visit_Name to be + # aware that we're inside function default values, which have special + # semantics. + + # According to the AST docs, the arguments node has the following structure. + # + # arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + # expr* kw_defaults, arg? 
kwarg, expr* defaults) + def visit_defaults(defaults): + try: + assert not self.visiting_arg_default_value + self.visiting_arg_default_value = True + for expr in defaults: + if expr is not None: + self.visit(expr) + finally: + self.visiting_arg_default_value = False + + for arg in itertools.chain(node.posonlyargs, node.args, [node.vararg] if node.vararg else [], node.kwonlyargs): + self.visit(arg) + + visit_defaults(node.kw_defaults) + + if node.kwarg is not None: + self.visit(node.kwarg) + + visit_defaults(node.defaults) + + def visitAssnTarget(self, node): + # Target is either a single string, or a list of strings (if the assn + # target is a tuple). + target = self.visit(node) + if isinstance(target, list): + self.local_names |= set(target) + else: + self.local_names.add(target) + + def visit_Assign(self, node): + if len(node.targets) != 1: + # TODO(jlebar): I don't actually know how to hit this. You don't + # get it from `a, b = ...` -- in that case, node.targets is a single + # Tuple, and in fact we *do* need to handle that case if we want + # existing code to work. + raise TypeError("Simultaneous multiple assignment is not supported.") + + self.visitAssnTarget(node.targets[0]) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_AnnAssign(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's OK. + self.generic_visit(node) + + def visit_For(self, node): + self.visitAssnTarget(node.target) + + # This will re-visit the target, but that's fine. + self.generic_visit(node) + + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +def _normalize_ty(ty) -> str: + if isinstance(ty, type): + return ty.__name__ + elif isinstance(ty, str): + return ty + return repr(ty) + + +class KernelParam: + """Represents a parameter (name plus metadata) to a @jit'ed function.""" + + def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool, + do_not_specialize_on_alignment: bool): + self.num = num + self._param = param + self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment + + @cached_property + def name(self): + return self._param.name + + @cached_property + def annotation(self): + if not self._param.annotation or self._param.annotation == inspect.Parameter.empty: + return "" + return _normalize_ty(self._param.annotation) + + @cached_property + def annotation_type(self): + annotation = self.annotation + for ty1, ty2 in [("uint", 'u'), ("int", 'i')]: + width = annotation[annotation.find(ty1) + len(ty1):] + if width and ty1 in annotation: + return f"{ty2}{width}" + if annotation == "bool": + return "u1" + return "" + + @cached_property + def is_constexpr(self): + return "constexpr" in self.annotation + + @cached_property + def is_const(self): + return "const" in self.annotation and not self.is_constexpr + + @property + def default(self): + return self._param.default + + @property + def has_default(self): + return self._param.default != inspect.Parameter.empty + + +def compute_spec_key(v, align): + + if align and hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0): + return "D" + elif isinstance(v, int): + # bool is a subclass of int, so we don't check explicitly above. 
+ if align and (v % 16 == 0): + return "D" + elif v == 1: + return "1" + return "N" + + +dtype2str = {} + + +def mangle_type(arg, is_const=False): + + if arg is None: + return "none" + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + elif hasattr(arg, "tma_desc_cpu_ptr"): + return "nvTmaDesc" + else: + # dtypes are hashable so we can memoize this mapping: + dsk = (arg.dtype, is_const) + res = dtype2str.get(dsk, None) + if res is None: + res = ("*k" if dsk[1] else "*") + type_canonicalisation_dict[str(dsk[0]).split('.')[-1]] + dtype2str[dsk] = res + return res + + +class KernelInterface(Generic[T]): + run: T + + def __getitem__(self, grid) -> T: + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. + """ + return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) + + +def serialize_specialization_data(name, signature, constants, attrs, options, key): + constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()} + import json + obj = { + 'name': name, 'signature': signature, 'constants': constants, 'attrs': attrs.to_dict(), 'options': + options.__dict__, 'key': key + } + serialized_obj = json.dumps(obj) + return serialized_obj + + +def create_function_from_signature(sig, kparams, backend): + """ + Equivalent to sig.bind followed by apply_defaults. This generates a + native Python function (using exec) which can be memoized on a per-kernel + basis to avoid having to run these expensive functions -- which constitute + much of the kernel launch overhead -- every time we run the kernel. 
+ """ + + assert len(sig.parameters) == len(kparams) + + # Create the function argument list and the dict entries for the return statement + func_args = [] + dict_entries = [] + constexpr_vals = [] + non_constexpr_vals = [] + signature_types = [] + specialisations = [] + + for ((name, sp), kp) in zip(sig.parameters.items(), kparams): + if sp.default is inspect.Parameter.empty: + func_args.append(name) + dict_entries.append(f"'{name}': {name}") + else: + func_args.append(f"{name}=default_{name}") + dict_entries.append(f"'{name}': {name}") + if kp.is_constexpr: + constexpr_vals.append(name) + else: + non_constexpr_vals.append(name) + if not kp.do_not_specialize: + if not kp.do_not_specialize_on_alignment: + specialisations.append('compute_spec_key(%s, align=True)' % name) + else: + specialisations.append('compute_spec_key(%s, align=False)' % name) + if kp.annotation_type: + signature_types.append('"%s"' % kp.annotation_type) + else: + signature_types.append('mangle_type(%s, %s)' % (name, 'True' if kp.is_const else 'False')) + + cache_key = ''.join([x + ', ' for x in signature_types + specialisations]) + constexpr_vals = ''.join([x + ', ' for x in constexpr_vals]) + non_constexpr_vals = ''.join([x + ', ' for x in non_constexpr_vals]) + + func_args.append('**excess_kwargs') + + # Join all arguments into a function definition string + args_str = ', '.join(func_args) + dict_str = ', '.join(dict_entries) + func_body = "def dynamic_func(%s):\n return {%s}, (%s), (%s), (%s), excess_kwargs" % ( + args_str, dict_str, cache_key, constexpr_vals, non_constexpr_vals) + + # Prepare defaults to be inserted into function namespace + func_namespace = { + f"default_{name}": param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + + func_namespace['mangle_type'] = mangle_type + func_namespace['compute_spec_key'] = backend.compute_spec_key + + # Execute the function string in func_namespace to create the function + exec(func_body, func_namespace) + + # Extract the newly created function from the namespace + return func_namespace['dynamic_func'] + + +type_canonicalisation_dict = { + "bool": "i1", + "float8e4nv": "fp8e4nv", + "float8e5": "fp8e5", + "float8e4b15": "fp8e4b15", + "float8_e4m3fn": "fp8e4nv", + "float8e4b8": "fp8e4b8", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2": "fp8e5", + "float8e5b16": "fp8e5b16", + "float8_e5m2fnuz": "fp8e5b16", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", +} + +for v in list(type_canonicalisation_dict.values()): + type_canonicalisation_dict[v] = v + + +class JITFunction(KernelInterface[T]): + # Hook for inspecting compiled functions and modules + cache_hook = None + # Hook to signal that a kernel is done compiling and inspect compiled function. + # cache_hook will always be called before compilation and compiled_hook after. 
+ compiled_hook = None + + @staticmethod + def _key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -(2**31) <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return "fp32" + elif arg is None: + return None + else: + raise TypeError(f"Unsupported type {type(arg)} for {arg}") + + @staticmethod + def _type_of(key, is_const=False): + # `None` is nullptr. Implicitly convert to *i8. + if key is None: + return "*i8" + elif isinstance(key, str): + return key + + dtype_str = str(key).split(".")[-1] + dtype_str = type_canonicalisation_dict[dtype_str] + const_str = "*k" if is_const else "*" + return const_str + dtype_str + + def _make_constants(self, constexpr_key): + constants = dict(zip(self.constexprs, constexpr_key)) + return constants + + def _call_hook( + self, + key, + signature, + device, + constants, + options, + configs, + is_warmup, + before, + ): + hook = JITFunction.cache_hook if before else JITFunction.compiled_hook + if hook is None: + return False + + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ", ".join([f"{param.name}: {ty}" for param, ty in zip(self.params, key[1])]) + repr = f"{name}[num_warps={options.num_warps}, num_ctas={options.num_ctas}, num_stages={options.num_stages}, enable_fp_fusion={options.enable_fp_fusion}]({arg_reprs})" + + class JitFunctionInfo: + + def __init__(self, module, name, jit_function): + self.module = module + self.name = name + self.jit_function = jit_function + pass + + specialization_data = serialize_specialization_data(name, signature, constants, configs[0], options, key) + + kwargs = { + 'signature': signature, + 'device': device, + 'constants': constants, + 'num_warps': options.num_warps, + 'num_ctas': options.num_ctas, + 'num_stages': options.num_stages, + 'enable_fp_fusion': options.enable_fp_fusion, + 'extern_libs': options.extern_libs, + 'configs': configs, + 'specialization_data': specialization_data, + 'is_warmup': is_warmup, + } + + return hook( + key=key, + repr=repr, + fn=JitFunctionInfo(module, name, self), + compile={"key": key, **kwargs}, + is_manual_warmup=is_warmup, + already_compiled=False, + ) + + def add_pre_run_hook(self, hook): + ''' + Add a hook that will be executed prior to the execution of run + function with args and kwargs passed into the kernel + ''' + assert callable(hook) + self.pre_run_hooks.append(hook) + + def create_binder(self, backend): + """ + Precompute as much as possible. 
+ """ + from ..compiler import CompiledKernel, compile, ASTSource, make_backend + self.CompiledKernel = CompiledKernel + self.compile = compile + self.ASTSource = ASTSource + self.make_backend = make_backend + self.binder = create_function_from_signature(self.signature, self.params, backend) + self.constexpr_indices = [i for (i, p) in enumerate(self.params) if p.is_constexpr] + self.non_constexpr_indices = [i for (i, p) in enumerate(self.params) if not p.is_constexpr] + self.specialised_indices = [ + i for (i, p) in enumerate(self.params) if (not p.do_not_specialize) and (not p.is_constexpr) + ] + + def run(self, *args, grid, warmup, **kwargs): + kwargs["debug"] = kwargs.get("debug", False) or os.environ.get("TRITON_DEBUG", "0") == "1" + + # parse options + from ..compiler import make_backend + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + target = driver.active.get_current_target() + backend = make_backend(target) + + # Execute pre run hooks with args and kwargs + for hook in self.pre_run_hooks: + hook(*args, **kwargs) + + if self.binder is None: + self.create_binder(backend) + + bound_args, sig_and_spec, constexpr_vals, non_constexpr_vals, excess_kwargs = self.binder(*args, **kwargs) + + # compute cache key + key = ''.join(sig_and_spec) + str((constexpr_vals, excess_kwargs)) + kernel = self.cache[device].get(key, None) + + if kernel is None: + # Kernel is not cached; we have to compile. + options = backend.parse_options(kwargs) + + # deprecated arguments + assert "device_type" not in kwargs, "device_type option is deprecated; current target will be used" + assert "device" not in kwargs, "device option is deprecated; current device will be used" + assert "stream" not in kwargs, "stream option is deprecated; current stream will be used" + for k in excess_kwargs: + if k not in options.__dict__: + raise KeyError("Keyword argument %s was specified but unrecognised" % k) + + bound_vals = tuple(bound_args.values()) + + # `None` is nullptr. Implicitly convert to *i8. This needs to be + # done here rather than when we build the signature as otherwise + # the kernel cache key could not distinguish between byte pointers + # and None arguments, resulting in a downstream mismatch: + sigkeys = [self.params[i].name for i in self.non_constexpr_indices] + sigvals = sig_and_spec[:len(sigkeys)] + signature = {k: ('*i8' if (v == 'none') else v) for (k, v) in zip(sigkeys, sigvals)} + + configs = (backend.get_attrs_descriptor(self.params, bound_vals), ) + constant_params = configs[0].get_constants() + constants = { + p.name: v + for (v, p) in zip(bound_vals, self.params) + if p.is_constexpr or (p.num in constant_params) or v is None + } + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + + if self._call_hook(key, signature, device, constants, options, configs, warmup, before=True): + return None + # compile the kernel + src = self.ASTSource(self, signature, constants, configs[0]) + kernel = self.compile( + src, + target=target, + options=options.__dict__, + ) + self.cache[device][key] = kernel + self._call_hook(key, signature, device, constants, options, configs, warmup, before=False) + + # Check that used global values have not changed. 
+ not_present = object() + for (name, _), (val, globals_dict) in self.used_global_vals.items(): + if (newVal := globals_dict.get(name, not_present)) != val: + raise RuntimeError( + f"Global variable {name} has changed since we compiled this kernel, from {val} to {newVal}") + + if not warmup: + # canonicalize grid + assert grid is not None + if callable(grid): + # Arguments are passed as a dict to `grid`, by contract. + # TODO(jlebar): In the new launch API, pass the compiler flags as a + # second parameter to `grid`. + grid = grid(bound_args) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + + # launch kernel + launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals) + kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata, + self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals) + return kernel + + def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None, + noinline=None, repr=None, launch_metadata=None): + do_not_specialize = do_not_specialize if do_not_specialize else [] + do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else [] + + self.fn = fn + self.module = fn.__module__ + self.version = version + self.signature = inspect.signature(fn) + self.do_not_specialize = do_not_specialize + self.do_not_specialize_on_alignment = do_not_specialize_on_alignment + self.starting_line_number = inspect.getsourcelines(fn)[1] + self.repr = lambda _: fn.__name__ if repr is None else repr(_) + self.launch_metadata = launch_metadata + + self.binder = None + + self.params = [] + for i, param in enumerate(self.signature.parameters.values()): + dns = i in do_not_specialize or param.name in do_not_specialize + dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment + self.params.append(KernelParam(i, param, dns, dns_oa)) + + # function source code (without decorators) + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[re.search(r"^def\s+\w+\s*\(", self.src, re.MULTILINE).start():] + # cache of just-in-time compiled kernels + self.cache = defaultdict(dict) + self.hash = None + + # Map of global variables used by the function and any functions it + # transitively calls, plus their values. The values are collected when + # the function is first compiled. Then every time we run the function, + # we check that the values of the globals match what's expected, + # otherwise we raise an error. + # + # Different functions can have different __globals__ maps, so the map + # key is actually (var name, id(__globals__)), and the map value is + # (value, __globals__). + self.used_global_vals: Dict[Tuple[str, int], Tuple[Any, Dict[str, Any]]] = {} + + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel = None + self.noinline = noinline + + # TODO(jlebar): Remove uses of these fields outside this file, then + # remove the fields here. 
+ self.arg_names = [p.name for p in self.params] + self.constexprs = [p.num for p in self.params if p.is_constexpr] + + # Hooks that will be called prior to executing "run" + self.pre_run_hooks = [] + + # reuse docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + @property + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(name=self.__name__, globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + str(self.starting_line_number) + self.used_global_vals = dict(sorted(dependencies_finder.used_global_vals.items())) + return self.hash + + def warmup(self, *args, grid, **kwargs): + return self.run(grid=grid, warmup=True, *map(MockTensor.wrap_dtype, args), **kwargs) + + def preload(self, specialization_data): + from ..compiler import compile, ASTSource + from triton.backends.compiler import AttrsDescriptor + import json + import triton.language as tl + device = driver.active.get_current_device() + deserialized_obj = json.loads(specialization_data) + if deserialized_obj['name'] != self.fn.__name__: raise RuntimeError( - "grid should be less than 65536! You can try \"export TRITON_ALL_BLOCKS_PARALLEL=1\" to avoid this problem." + f"Specialization data is for {deserialized_obj['name']} but trying to preload for {self.fn.__name__}") + constants = { + key: tl.dtype(value) if tl.dtype.is_dtype(value) else value + for key, value in deserialized_obj['constants'].items() + } + signature = dict(deserialized_obj['signature'].items()) + src = ASTSource(self, signature, constants, AttrsDescriptor.from_dict(deserialized_obj['attrs'])) + options = { + key: tuple(value) if isinstance(value, list) else value + for key, value in deserialized_obj['options'].items() + } + key = deserialized_obj['key'] + kernel = compile(src, None, options) + self.cache[device][key] = kernel + return kernel + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. + def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + super(JITFunction, self).__setattr__(name, value) + # - when `.src` attribute is set, cache path needs + # to be reinitialized + if name == "src": + self.hash = None + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +@overload +def jit(fn: T) -> JITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], JITFunction[T]]: + ... 
+ + +def jit( + fn: Optional[T] = None, + *, + version=None, + repr: Optional[Callable] = None, + launch_metadata: Optional[Callable] = None, + do_not_specialize: Optional[Iterable[int]] = None, + do_not_specialize_on_alignment: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + if os.getenv("TRITON_INTERPRET", "0") == "1": + from .interpreter import InterpretedFunction + return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug, + noinline=noinline, repr=repr, launch_metadata=launch_metadata) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + do_not_specialize_on_alignment=do_not_specialize_on_alignment, + debug=debug, + noinline=noinline, + repr=repr, + launch_metadata=launch_metadata, ) + if fn is not None: + return decorator(fn) + + else: + return decorator + + +# ----------------------------------------------------------------------------- +# Utilities for mocking tensors +# ----------------------------------------------------------------------------- + + +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) 
+ """ + + @staticmethod + def wrap_dtype(arg): + if arg.__class__.__name__ == "dtype" and arg.__module__ == "torch": + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + @staticmethod + def data_ptr(): + return 0 # optimistically assumes multiple of 16 + + @staticmethod + def ptr_range(): + return 0 # optimistically assumes 32 bit pointer range + + +class TensorWrapper: + + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.data = base.data + self.device = base.device + self.shape = self.base.shape + + def data_ptr(self): + return self.base.data_ptr() -def explicit_load_kernel_library(kernel): - # explicitly define run method and load kernel binary - kernel._init_handles() + def stride(self, i): + return self.base.stride(i) + def __str__(self) -> str: + return f"TensorWrapper[{self.dtype}]({self.base})" -def get_JITFunction_spec_attr(deserialized_obj): - from triton.backends.ascend.compiler import AscendAttrsDescriptor - return AscendAttrsDescriptor.from_dict(deserialized_obj['attrs']) + def element_size(self): + return self.base.element_size() + def cpu(self): + return TensorWrapper(self.base.cpu(), self.dtype) -def maps_line_numbers_to_comment_hints(jit_fn): - import tokenize - from io import StringIO - # Maps line numbers to comment hints - line_flagtree_hints = {} - code_str = jit_fn.src - g = tokenize.generate_tokens(StringIO(code_str).readline) - for tok_type, tok_text, start, end, _ in g: - if tok_type == tokenize.COMMENT: - comment = tok_text.replace(" ", "").strip() - if comment.startswith('#@hint:'): - flagtree_hints = comment[len('#@hint:'):].strip() - # Record the line number of the comment - line_num = start[0] - line_flagtree_hints[line_num] = flagtree_hints + def copy_(self, other): + self.base.copy_(other.base) - # print(f"[FLAGTREE] Parsed hint at line {line_num}: {flagtree_hints}") + def clone(self): + return TensorWrapper(self.base.clone(), self.dtype) - return line_flagtree_hints + def to(self, device): + return TensorWrapper(self.base.to(device), self.dtype) -def attach_line_number_to_comment_mapping(tree, line_flagtree_hints): - # Attach the line number to comment mapping to the function definition node - tree.body[0].line_flagtree_hints = line_flagtree_hints +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. + return TensorWrapper(tensor.base, dtype) + elif hasattr(tensor, "data_ptr"): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f"Cannot reinterpret a {type(tensor)}.") -def enable_extra_option(): - return True +def get_jit_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + # Match the following pattern: + # @triton.autotune(...) <- foo.__code__.co_firstlineno + # @triton.heuristics(...) 
+ # @triton.jit + # def foo(...): <- this line is the first line + for idx, line in enumerate(lines): + if line.strip().startswith("def "): + begin_line += idx + break + return file_name, begin_line diff --git a/third_party/ascend/backend/spec/triton/runtime/libentry.py b/third_party/ascend/backend/spec/triton/runtime/libentry.py index 93240f276..a358b9ae8 100644 --- a/third_party/ascend/backend/spec/triton/runtime/libentry.py +++ b/third_party/ascend/backend/spec/triton/runtime/libentry.py @@ -1,16 +1,25 @@ -# Copyright © 2024 BAAI +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# Copyright © 2024 BAAI. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: # -# http://www.apache.org/licenses/LICENSE-2.0 +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
# Modifications: # - 2025-06-03: @@ -26,7 +35,7 @@ import ast import triton -from triton._C import libentry_ascend +from triton._C import libentryC import torch import torch_npu @@ -212,7 +221,7 @@ def run(self, *args, **kwargs): dns_args = [] # do not specialize arguments const_args = [] # constexpr arguments k_args = [] # kernel arguments - arg_processor = libentry_ascend.ArgProcessor(self.divisibility) + arg_processor = libentryC.ArgProcessor(self.divisibility) arg_processor.classify_arguments(list(args), kwargs, self.jit_function.params, set(self.specialize_indices), set(self.do_not_specialize_indices)) diff --git a/third_party/ascend/backend/spec/triton/runtime/utils.py b/third_party/ascend/backend/spec/triton/runtime/utils.py deleted file mode 100644 index 4cd9b2feb..000000000 --- a/third_party/ascend/backend/spec/triton/runtime/utils.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch - -from triton.runtime.driver import driver - -# npu hardware params -target = driver.active.get_current_target() -device = driver.active.get_current_device() -prop = driver.active.utils.get_device_properties(device) - -num_cube_core = prop["num_aicore"] -num_vector_core = prop["num_aicore"] -# flagtree backend specialization: num_ub_max -num_ub_max = 192 - -# flagtree backend specialization -ASCEND_VARIANTS = ["Ascend910B", "Ascend910_93", "Ascend910_95"] -if any(variant in target.arch for variant in ASCEND_VARIANTS): - num_vector_core = num_cube_core * 2 - -# flagtree backend specialization -if '910_95' in target.arch: - num_ub_max = 256 - -# wrapper npu 32 bytes align, get and pass unalign info to triton meta -# then autotune choose tiling param and send them to bishengIR -byte_per_numel = { - torch.float32: 4, # torch.float32 or torch.float - torch.float64: 8, # torch.float64 or torch.double - torch.float16: 2, # torch.float16 or torch.half - torch.bfloat16: 2, # torch.bfloat16 - torch.int32: 4, # torch.int32 or torch.int - torch.int64: 8, # torch.int64 or torch.long - torch.int16: 2, # torch.int16 or torch.short - torch.int8: 1, # torch.int8 - torch.uint8: 1, # torch.uint8 - torch.bool: 1, # torch.bool - torch.complex32: 4, # torch.complex32 (not yet available in PyTorch as of the latest stable release) - torch.complex64: 8, # torch.complex64 - torch.complex128: 16, # torch.complex128 -} - -# flagtree backend specialization: replace '{}' with '[]' -valid_axis_names = [ - "x", - "y", - "z", - "w", - "v", - "t", - "rx", - "ry", - "rz", - "rw", - "rv", - "rt", -] - - -def get_byte_per_numel(dtype: torch.dtype) -> int: - return 1 if dtype is None else byte_per_numel[dtype] - - -def is_valid_axis_name(name: str) -> bool: - return name in valid_axis_names - - -# move to an appropriate place, currently duplicated with triton.__init__.py -def next_power_of_2(n: int): - """Return the smallest power of 2 greater than or equal to n""" - n -= 1 - n |= n >> 1 - n |= n >> 2 - n |= n >> 4 - n |= n >> 8 - n |= n >> 16 - n |= n >> 32 - n += 1 - return n diff --git a/third_party/ascend/backend/spec/triton/testing.py b/third_party/ascend/backend/spec/triton/testing.py index dd298d5b9..71cb8ab1e 100644 --- a/third_party/ascend/backend/spec/triton/testing.py +++ b/third_party/ascend/backend/spec/triton/testing.py @@ -1,228 +1,511 @@ -import torch +import functools import os -import multiprocessing -from concurrent.futures import ThreadPoolExecutor -from datetime import datetime, timezone -import logging -import builtins -from triton import runtime - - -def is_do_bench_npu(): - enable_bench_npu = 
os.getenv("TRITON_BENCH_METHOD", 'default').lower() == 'npu' - if torch.npu.is_available() and enable_bench_npu: - return True - return False - - -def collect_files(base_dir): - import pandas as pd - for root, dirs, files in os.walk(base_dir): - for file in files: - if file != 'op_statistic.csv': - continue - target_file = os.path.join(root, file) - df = pd.read_csv(target_file) - triton_rows = df[df['OP Type'].str.startswith('triton', na=False)] - if not triton_rows.empty: - return triton_rows['Avg Time(us)'].values[0] - return float('inf') - return float('inf') - - -def collect_single(base_dir: str, key: str = None) -> float: - if not os.path.exists(base_dir): - return float('inf') - - import pandas as pd - for root, _, files in os.walk(base_dir): - for file in files: - if file != 'op_statistic.csv': - continue - target_file = os.path.join(root, file) - df = pd.read_csv(target_file) - if key is not None: - key_rows = df[df['OP Type'].str.startswith(key, na=False)] - if not key_rows.empty: - return key_rows['Avg Time(us)'].values[0] - return float('inf') - else: - # default: read the first row except header - return df.loc[0, 'Avg Time(us)'] - - return float('inf') +import subprocess +import sys +from contextlib import contextmanager +from typing import Any, Dict, List +from . import language as tl +from . import runtime -def _rm_dic(keep_res, torch_path): - if keep_res: - return - import shutil - if os.path.exists(torch_path): - shutil.rmtree(torch_path) +def nvsmi(attrs): + attrs = ','.join(attrs) + cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(',') + ret = [int(x) for x in ret] + return ret -def _collect_mul_prof_result(base_dir: str, kernel_dict, total, key: str = None): - import numpy as np - import pandas as pd - tiling_dict = {} - kernel_details_file = None - for root, _, files in os.walk(base_dir): - for file in files: - if file == "kernel_details.csv": - kernel_details_file = os.path.join(root, file) - break - num_funcs = len(kernel_dict) - if kernel_details_file is None or os.path.exists(kernel_details_file) is False: - for config, _ in kernel_dict.items(): - tiling_dict[config] = [float('inf')] - return tiling_dict - df = pd.read_csv(kernel_details_file) - # filter out l2 cache clear operation - filter_cond = ~df["Name"].str.contains(r"zero|ZerosLike", case=False, na=False) - filter_df = df[filter_cond] - if key is not None: - key_rows = filter_df[filter_df["Name"].str.contains(key, na=False)] - else: - key_rows = filter_df - time_cost = [0] * num_funcs - for func_idx in np.arange(0, num_funcs): - for active_index in np.arange(0, total): - row_index = active_index + func_idx * total - time_cost[func_idx] += key_rows.iloc[row_index]["Duration(us)"] - time_cost = [x / total for x in time_cost] - for (config, avg_time) in zip(kernel_dict.keys(), time_cost): - tiling_dict[config] = [avg_time] - return tiling_dict - - -def do_bench_npu(fn, warmup=5, active=30, prof_dir=None, keep_res=False): - import torch_npu - import multiprocessing - from triton import runtime - - # warmup kernel - fn() - torch.npu.synchronize() - - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) - skip_first = 1 - wait = 0 - repeat = 1 - total = skip_first + (wait + warmup + active) * repeat - - if prof_dir is not 
None: - torch_path = prof_dir - else: - process = multiprocessing.current_process() - pid = process.pid - process_name = process.name - timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") - base_path = os.path.join(runtime.cache.get_home_dir(), ".triton", "profile_results") - torch_path = os.path.join(base_path, f"prof_{timestamp}_{process_name}-{pid}") - with torch_npu.profiler.profile( - activities=[torch_npu.profiler.ProfilerActivity.NPU], - schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, - skip_first=skip_first), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), - record_shapes=False, - profile_memory=False, - with_stack=False, - with_flops=False, - with_modules=False, - experimental_config=experimental_config, - ) as prof: - for _ in builtins.range(total): +def _summarize_statistics(times, quantiles, return_mode): + import torch + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + if return_mode == "all": + return times.tolist() + return getattr(torch, return_mode)(times).item() + + +def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. + + :param fn: Function to benchmark + :type fn: Callable + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". + :type return_mode: str + """ + import torch + assert return_mode in ["min", "max", "mean", "median", "all"] + + with torch.cuda.stream(torch.cuda.Stream()): + # warmup + fn() + if grad_to_none is not None: + for x in grad_to_none: + x.detach_() + x.requires_grad_(True) + x.grad = None + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough + # NOTE: we don't use a graph to estimate the runtime because creating a graph is expensive, + # ~300ms on A100, so we default to the same method used in `do_bench` (minus the L2 + # cache flush). 
+ start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): fn() - prof.step() - torch.npu.synchronize() - - time = collect_single(torch_path) - _rm_dic(keep_res, torch_path) - return time + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + n_repeat = max(1, int(rep / estimate_ms)) + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(n_repeat): + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + fn() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for _ in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() + torch.cuda.synchronize() + ret += [start_event.elapsed_time(end_event) / n_repeat] + return _summarize_statistics(torch.tensor(ret), quantiles, return_mode) + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float], optional + :param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all" Default is "mean". :type return_mode: str + """ + assert return_mode in ["min", "max", "mean", "median", "all"] + import torch + di = runtime.driver.active.get_device_interface() -def do_bench_multiple_kernel_npu(kernel_dict, active=30, prof_dir=None, keep_res=False): + fn() + di.synchronize() + + cache = runtime.driver.active.get_empty_cache_for_benchmark() + + # Estimate the runtime of the function + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + di.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. 
So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + di.synchronize() + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) + return _summarize_statistics(times, quantiles, return_mode) + + +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + """ + Asserts that two inputs are close within a certain tolerance. + + :param x: The first input. + :type x: scala, list, numpy.ndarray, or torch.Tensor + :param y: The second input. + :type y: scala, list, numpy.ndarray, or torch.Tensor + :param atol: The absolute tolerance. Default value is 1e-2. + :type atol: float, optional + :param rtol: The relative tolerance. Default value is 0. + :type rtol: float, optional + :param err_msg: The error message to use if the assertion fails. + :type err_msg: str + """ + import numpy as np import torch - import torch_npu - from .compiler.errors import CompileTimeAssertionFailure, MLIRCompilationError, CompilationError - assert len(kernel_dict) > 0, f"ERROR: length of kernel_dict is {len(kernel_dict)}, no kernel is profiling." - - # warmup kernel - def run_fn(fn): - try: - fn() - except (CompileTimeAssertionFailure, MLIRCompilationError, CompilationError) as ex: - raise ex - - def run_all_fns(): - import psutil - max_workers = psutil.cpu_count(logical=False) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for _, fn in kernel_dict.items(): - future = executor.submit(run_fn, fn) - futures.append(future) - for future in futures: + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance + if atol is None: + atol = 1e-2 + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook + if rtol is None: + rtol = 0. + rtol = rtol(x.dtype) if callable(rtol) else rtol + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) + return + if not np.allclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + +class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ + + def __init__( + self, + x_names: List[str], + x_vals: List[Any], + line_arg: str, + line_vals: List[Any], + line_names: List[str], + plot_name: str, + args: Dict[str, Any], + xlabel: str = '', + ylabel: str = '', + x_log: bool = False, + y_log: bool = False, + styles=None, + ): + """ + Constructor. + x_vals can be a list of scalars or a list of tuples/lists. If x_vals is a list + of scalars and there are multiple x_names, all arguments will have the same value. + If x_vals is a list of tuples/lists, each element should have the same length as + x_names. 
+ + :param x_names: Name of the arguments that should appear on the x axis of the plot. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[Any] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: Dictionary of keyword arguments to remain fixed throughout the benchmark. + :type args: Dict[str, Any] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. + :type y_log: bool, optional + :param styles: A list of tuples, where each tuple contains two elements: a color and a linestyle. + :type styles: list[tuple[str, str]] + """ + self.x_names = x_names + self.x_vals = x_vals + self.x_log = x_log + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names + self.y_log = y_log + self.styles = styles + # plot info + self.xlabel = xlabel + self.ylabel = ylabel + self.plot_name = plot_name + self.args = args + + +class Mark: + + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, diff_col=False, + save_precision=6, **kwrags): + import os + + import matplotlib.pyplot as plt + import pandas as pd + y_mean = bench.line_names + y_min = [f'{x}-min' for x in bench.line_names] + y_max = [f'{x}-max' for x in bench.line_names] + x_names = list(bench.x_names) + df = pd.DataFrame(columns=x_names + y_mean + y_min + y_max) + for x in bench.x_vals: + # x can be a single value or a sequence of values. + if not isinstance(x, (list, tuple)): + x = [x for _ in x_names] + + if len(x) != len(x_names): + raise ValueError(f"Expected {len(x_names)} values, got {x}") + x_args = dict(zip(x_names, x)) + + row_mean, row_min, row_max = [], [], [] + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) try: - future.result() - except Exception as ex: - logging.info(f"Exception raised while benchmarking function.{ex}") - - run_all_fns() + y_mean, y_min, y_max = ret + except TypeError: + y_mean, y_min, y_max = ret, None, None + row_mean += [y_mean] + row_min += [y_min] + row_max += [y_max] + df.loc[len(df)] = list(x) + row_mean + row_min + row_max + + if bench.plot_name: + plt.figure() + ax = plt.subplot() + # Plot first x value on x axis if there are multiple. 
+ first_x = x_names[0] + for i, y in enumerate(bench.line_names): + y_min, y_max = df[y + '-min'], df[y + '-max'] + col = bench.styles[i][0] if bench.styles else None + sty = bench.styles[i][1] if bench.styles else None + ax.plot(df[first_x], df[y], label=y, color=col, ls=sty) + if not y_min.isnull().all() and not y_max.isnull().all(): + y_min = y_min.astype(float) + y_max = y_max.astype(float) + ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) + ax.legend() + ax.set_xlabel(bench.xlabel or first_x) + ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + df = df[x_names + bench.line_names] + if diff_col and df.shape[1] == 2: + col0, col1 = df.columns.tolist() + df['Diff'] = df[col1] - df[col0] + + if print_data: + print(bench.plot_name + ':') + print(df.to_string()) + if save_path: + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format=f"%.{save_precision}f", + index=False) + return df + + def run(self, show_plots=False, print_data=False, save_path='', return_df=False, **kwargs): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + result_dfs = [] + if save_path: + # Create directory if it doesn't exist + os.makedirs(save_path, exist_ok=True) + html = open(os.path.join(save_path, "results.html"), "w") + html.write("\n") + for bench in benchmarks: + result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs)) + if save_path: + html.write(f"\n") + if save_path: + html.write("\n") + html.close() + if return_df: + if has_single_bench: + return result_dfs[0] + else: + return result_dfs + return None - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) - if prof_dir is not None: - torch_path = prof_dir - else: - process = multiprocessing.current_process() - pid = process.pid - process_name = process.name - timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") - base_path = os.path.join(runtime.cache.get_home_dir(), ".triton", "profile_results") - torch_path = os.path.join(base_path, f"prof_{timestamp}_{process_name}-{pid}") - - l2_cache_size = 192 * (1 << 20) - buffer = torch.empty(l2_cache_size // 4, dtype=torch.int, device="npu") - buffer.zero_() - torch.npu.synchronize() # shake out of any npu error - - with torch_npu.profiler.profile( - activities=[torch_npu.profiler.ProfilerActivity.NPU], - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), - record_shapes=False, - profile_memory=False, - with_stack=False, - with_flops=False, - with_modules=False, - experimental_config=experimental_config, - ) as prof: - for _, fn in kernel_dict.items(): - for _ in builtins.range(active): - buffer.zero_() - fn() - torch.npu.synchronize() - del buffer +def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. - tiling_dict = _collect_mul_prof_result(base_dir=torch_path, kernel_dict=kernel_dict, total=active) - _rm_dic(keep_res, torch_path) - return tiling_dict + :param benchmarks: Benchmarking configurations. 
+ :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper -def ext_do_bench_npu(fn, warmup, rep, quantiles, return_mode): +def get_dram_gbps(device=None): + ''' return DRAM bandwidth in GB/s ''' import torch - from triton.testing import _summarize_statistics - avg_time = do_bench_npu(fn, warmup=max(5, warmup), active=max(30, rep)) - return _summarize_statistics(torch.tensor([avg_time], dtype=torch.float), quantiles, return_mode) + from .runtime import driver + if not device: + device = torch.cuda.current_device() + mem_clock_khz = driver.active.utils.get_device_properties(device)["mem_clock_rate"] # in kHz + bus_width = driver.active.utils.get_device_properties(device)["mem_bus_width"] + bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s + return bw_gbps -def testing_spec_range(num): - return builtins.range(num) +def get_max_tensorcore_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() -testing_ext_spec_api_list = ["do_bench_npu", "do_bench_multiple_kernel_npu"] + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8: + assert dtype == torch.float16 + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + else: + if dtype in [torch.float32, torch.int32]: + ops_per_sub_core = 256 + elif dtype in [torch.float16, torch.bfloat16, torch.int16]: + ops_per_sub_core = 512 + elif dtype in [torch.int8, tl.float8e4nv, tl.float8e4b15, tl.float8e5]: + ops_per_sub_core = 1024 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops + + +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 
108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + + +def get_max_simd_tflops(dtype, clock_rate, device=None): + import torch + + from .runtime import driver + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + if dtype == torch.float32: + ops_per_sub_core = 32 # 2*16 + elif dtype == torch.float16: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + else: + if dtype == torch.float32: + ops_per_sub_core = 32 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops diff --git a/third_party/ascend/backend/testing.py b/third_party/ascend/backend/testing.py new file mode 100644 index 000000000..97e36ca2f --- /dev/null +++ b/third_party/ascend/backend/testing.py @@ -0,0 +1,148 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
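The new third_party/ascend/backend/testing.py module below centers on do_bench_npu, which runs one or more callables under torch_npu's profiler and reports average kernel time in milliseconds. A minimal usage sketch, assuming an Ascend host with torch_npu installed; the import path and the softmax_triton wrapper are assumptions, not part of this patch:

    import torch
    import torch_npu  # noqa: F401  (assumed available)
    from triton.backends.ascend.testing import do_bench_npu  # assumed runtime location of the new module

    x = torch.randn(4096, 1024, device="npu")

    # A single callable returns one float: the mean kernel time over the `active` iterations.
    ms_eager = do_bench_npu(lambda: torch.softmax(x, dim=-1), warmup=5, active=30)

    # A list of callables shares one profiling session; clear_l2_cache=True zeroes a scratch
    # buffer between launches, as the implementation below does.
    ms_eager, ms_triton = do_bench_npu(
        [lambda: torch.softmax(x, dim=-1), lambda: softmax_triton(x)],  # softmax_triton is hypothetical
        warmup=5, active=30, clear_l2_cache=True)
    print(ms_eager, ms_triton)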
+ +import builtins +import multiprocessing +import os +from datetime import datetime, timezone + +import triton.runtime as runtime + + +def do_bench_npu(funcs, warmup=5, active=30, clear_l2_cache=False, prof_dir=None, keep_res=False): + import torch + import torch_npu + + if not isinstance(funcs, list): + funcs = [funcs] + + # warmup kernel + for fn in funcs: + fn() + torch.npu.synchronize() + + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, + l2_cache=False, + data_simplification=False, + ) + + if prof_dir is not None: + torch_path = prof_dir + else: + process = multiprocessing.current_process() + pid = process.pid + process_name = process.name + timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") + base_path = os.path.join(runtime.cache.get_home_dir(), ".triton", "profile_results") + torch_path = os.path.join(base_path, f"prof_{timestamp}_{process_name}-{pid}") + + if clear_l2_cache: + buffer = runtime.driver.active.get_empty_cache_for_benchmark() + buffer.zero_() + torch.npu.synchronize() # shake out of any npu error + + total = warmup + active + with torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.NPU], + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path), + record_shapes=False, + profile_memory=False, + with_stack=False, + with_flops=False, + with_modules=False, + experimental_config=experimental_config, + ) as prof: + for fn in funcs: + for _ in builtins.range(total): + if clear_l2_cache: + buffer.zero_() + fn() + torch.npu.synchronize() + if clear_l2_cache: + del buffer + + time_cost = _collect_prof_result(torch_path, funcs, warmup, active) + _rm_dic(keep_res, torch_path) + return time_cost + + +def _rm_dic(keep_res, torch_path): + if keep_res: + return + import shutil + + if os.path.exists(torch_path): + shutil.rmtree(torch_path) + + +def _collect_prof_result(base_dir: str, funcs, num_warmup: int, num_active: int, key: str = None): + """ + Collect kernel performance from kernel_details.csv, returned in millisecond. + The first `num_warmup` rows of each function are warmup data and will be ignored, the next `num_active` rows will be averaged. 
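+    For example, with two callables and warmup=5, active=30, kernel_details.csv holds 70 rows:
+    rows 5-34 (counting from 0) are averaged for the first callable and rows 40-69 for the second.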
+ + :param base_dir: the profiler path + :type base_dir: str + :param funcs: a list of Callable being profiled + :type funcs: List[Callable] + :param num_warmup: warmup count in kernel_details.csv of each fn + :type num_warmup: int + :param num_active: active count in kernel_details.csv of each fn + :type num_active: int + :param key: filter key for kernel name + :type key: str + """ + + import numpy as np + import pandas as pd + + kernel_details_file = None + for root, _, files in os.walk(base_dir): + for file in files: + if file == "kernel_details.csv": + kernel_details_file = os.path.join(root, file) + break + num_funcs = len(funcs) + if kernel_details_file is None: + if num_funcs == 1: + return float("inf") + else: + return [float("inf")] * num_funcs + + df = pd.read_csv(kernel_details_file) + # filter out l2 cache clearing operation + filter_cond = ~df["Type"].str.contains(r"^ZerosLike$", case=False, na=False) + filter_df = df[filter_cond] + if key is not None: + key_rows = filter_df[filter_df["Name"].str.contains(key, na=False)] + else: + key_rows = filter_df + time_cost = [0] * num_funcs + for func_idx in np.arange(0, num_funcs): + for active_index in np.arange(0, num_active): + row_index = func_idx * (num_warmup + num_active) + num_warmup + active_index + time_cost[func_idx] += key_rows.iloc[row_index]["Duration(us)"] + time_cost = [x / num_active / 1e3 for x in time_cost] + + if num_funcs == 1: + return time_cost[0] + else: + return time_cost diff --git a/third_party/ascend/backend/utils.py b/third_party/ascend/backend/utils.py index 2fbb729cd..8f90b5fa4 100644 --- a/third_party/ascend/backend/utils.py +++ b/third_party/ascend/backend/utils.py @@ -1,5 +1,23 @@ -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
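In the utils.py changes that follow, every torch_npu/mindspore-specific decision is routed through a new get_backend_func helper, selected once per process from the TRITON_BACKEND environment variable or, failing that, by probing which framework imports. A sketch of the resulting call pattern; the import path is an assumption, and the hook names mirror the call sites in the hunks below:

    import os
    from triton.backends.ascend.utils import get_backend_func  # assumed runtime location

    # "torch_npu" or "mindspore"; anything else falls back to import probing.
    os.environ["TRITON_BACKEND"] = "torch_npu"

    version_parts = get_backend_func("version_hash")                  # framework versions folded into the precompile hash
    pch_flags = get_backend_func("get_cc_cmd", build_pch=True)        # include flags for the precompiled header
    launcher_flags = get_backend_func("get_cc_cmd", build_pch=False)  # include/link flags for the kernel launcher
    abi = get_backend_func("cxx_abi")                                 # 0/1, replaces the direct torch._C probe
    dtype_map = get_backend_func("type_convert")                      # framework dtype -> numpy dtype mapping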
+ import functools import hashlib import os @@ -10,9 +28,28 @@ from pathlib import Path import logging from platform import python_version +from triton.backends.ascend.backend_register import backend_strategy_registry import pybind11 +backend_policy = None + + +def get_backend_func(name, *args, **kwargs): + global backend_policy + if backend_policy is None: + backend_policy_env = os.getenv("TRITON_BACKEND", "default").lower() + if backend_policy_env == "torch_npu" or backend_policy_env == "mindspore": + backend_policy = backend_policy_env + if backend_policy is None: + try: + import torch + import torch_npu + backend_policy = "torch_npu" + except ImportError: + backend_policy = "mindspore" + return backend_strategy_registry.execute_func(backend_policy, name, *args, **kwargs) + def get_logger(logger_name, logger_level_str): ''' @@ -113,19 +150,26 @@ def _get_llvm_path(path: str, *paths) -> str: def _get_npucompiler_path() -> str: - npu_compiler_path = shutil.which("bishengir-compile") - if npu_compiler_path is None: - npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", "") - if npu_compiler_root is None: - raise EnvironmentError("Couldn't find executable bishengir-compile or TRITON_NPU_COMPILER_PATH.") - npu_compiler_path = os.path.join(npu_compiler_root, "npuc") - return npu_compiler_path + ascend_dir = os.path.dirname(os.path.abspath(__file__)) + env = os.environ.copy() + npu_compiler_path = os.path.join(ascend_dir, "bishengir", "bin", "bishengir-compile") + if os.path.exists(npu_compiler_path): + npuir_env_path = os.path.dirname(npu_compiler_path) + env["PATH"] = npuir_env_path + ":" + env["PATH"] + else: + npu_compiler_path = shutil.which("bishengir-compile") + if npu_compiler_path is None: + npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", None) + if npu_compiler_root is None: + raise EnvironmentError("Couldn't find executable bishengir-compile or TRITON_NPU_COMPILER_PATH.") + npu_compiler_path = os.path.join(npu_compiler_root, "npuc") + return npu_compiler_path, env def _get_bisheng_path() -> str: bisheng_path = shutil.which("bisheng") if bisheng_path is None: - npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", "") + npu_compiler_root = os.getenv("TRITON_NPU_COMPILER_PATH", None) if npu_compiler_root is None: raise EnvironmentError("Couldn't find executable bisheng or TRITON_NPU_COMPILER_PATH") bisheng_path = os.path.join(npu_compiler_root, "ccec") @@ -146,7 +190,7 @@ def _is_valid_bishengir_path(path: str) -> bool: # if bishengir-compile is a newer version which does not generate kernel_reloc.o # any more. 
def _check_bishengir_api_change() -> bool: - bishengir_path = _get_npucompiler_path() + bishengir_path, _ = _get_npucompiler_path() if not _is_valid_bishengir_path(bishengir_path): print(f"ERROR: Invalid bishengir path format: {bishengir_path}") return False @@ -169,7 +213,7 @@ def _check_bishengir_api_change() -> bool: def _check_bishengir_is_regbased() -> bool: - bishengir_path = _get_npucompiler_path() + bishengir_path, _ = _get_npucompiler_path() if not _is_valid_bishengir_path(bishengir_path): print(f"ERROR: Invalid bishengir path format: {bishengir_path}") return False @@ -249,14 +293,11 @@ def _get_cxx_precompiled(header_path): def _precompile_npu_hash(header_src): import sys - import torch - import torch_npu cxx = _get_cxx() py_version = sys.version - torch_version = torch.version.git_version - torch_npu_version = torch_npu.version.git_version asc_path = _get_ascend_path().name - version_txt = [header_src, cxx, py_version, torch_version, torch_npu_version, asc_path] + version_txt = [header_src, cxx, py_version, asc_path] + version_txt += get_backend_func("version_hash") hash_txt = hashlib.sha256("_".join(version_txt).encode("utf-8")).hexdigest() return hash_txt @@ -284,23 +325,22 @@ def _precompile_npu_ext(header_path): cc_cmd += [f"-I{os.path.dirname(os.path.realpath(__file__))}"] # find the ascend library asc_path = _get_ascend_path() + + rt_path = os.path.join(asc_path, "include/experiment/runtime/runtime/rt.h") + if not os.path.exists(rt_path): + cc_cmd += [ + f"-I{os.path.join(asc_path, 'pkg_inc')}", + f"-I{os.path.join(asc_path, 'pkg_inc/profiling')}", + ] + cc_cmd += [ f"-I{os.path.join(asc_path, 'include')}", f"-I{os.path.join(asc_path, 'include/experiment')}", f"-I{os.path.join(asc_path, 'include/experiment/msprof')}", f"-I{pybind11.get_include()}", ] - import torch - import torch_npu - torch_path = os.path.dirname(os.path.realpath(torch.__file__)) - torch_npu_path = os.path.dirname(os.path.realpath(torch_npu.__file__)) - use_cxx11_abi = _check_cxx11_abi() - cc_cmd += [ - f"-I{os.path.join(torch_path, 'include')}", - f"-I{os.path.join(torch_npu_path, 'include')}", - f"-D_GLIBCXX_USE_CXX11_ABI={use_cxx11_abi}", - ] + cc_cmd += get_backend_func("get_cc_cmd", build_pch=True) cc_cmd += ["-std=c++17", "-shared", "-fPIC", "-o", gch_path] @@ -342,6 +382,13 @@ def _build_npu_ext(obj_name: str, header_path, src_path, *, kernel_launcher="tor if header_path is not None: cc_cmd += [f"-I{os.path.dirname(header_path)}"] + rt_path = os.path.join(asc_path, "include/experiment/runtime/runtime/rt.h") + if not os.path.exists(rt_path): + cc_cmd += [ + f"-I{os.path.join(asc_path, 'pkg_inc')}", + f"-I{os.path.join(asc_path, 'pkg_inc/profiling')}", + ] + cc_cmd += [ f"-I{os.path.join(asc_path, 'include')}", f"-I{os.path.join(asc_path, 'include/experiment')}", @@ -353,19 +400,7 @@ def _build_npu_ext(obj_name: str, header_path, src_path, *, kernel_launcher="tor ] # FIXME: check why this condition works wrong in parall scene # if kernel_launcher == "torch": - import torch - import torch_npu - - torch_path = os.path.dirname(os.path.realpath(torch.__file__)) - torch_npu_path = os.path.dirname(os.path.realpath(torch_npu.__file__)) - use_cxx11_abi = _check_cxx11_abi() - cc_cmd += [ - f"-I{os.path.join(torch_path, 'include')}", - f"-I{os.path.join(torch_npu_path, 'include')}", - f"-L{os.path.join(torch_npu_path, 'lib')}", - "-ltorch_npu", - f"-D_GLIBCXX_USE_CXX11_ABI={use_cxx11_abi}", - ] + cc_cmd += get_backend_func("get_cc_cmd", build_pch=False) cc_cmd += ["-std=c++17", "-shared", "-fPIC", 
"-Winvalid-pch", "-o", so_path] @@ -401,9 +436,7 @@ def _get_kernel_target(metadata: dict): def _check_cxx11_abi(): - import torch - - return 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 + return get_backend_func("cxx_abi") def convert_sigtype_to_int(sigty: str): @@ -434,27 +467,12 @@ def convert_sigtype_to_int(sigty: str): return MAP_SIGTYPE_TO_INT[sigty] -def convert_torch_dtype_to_numpy(torch_dtype): - import torch - import numpy as np - TORCH_TO_NUMPY_DTYPE = { - torch.float32: np.float32, - torch.float64: np.float64, - torch.float16: np.float16, - torch.int8: np.int8, - torch.uint8: np.uint8, - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.int64, - torch.bool: np.bool_, - torch.complex64: np.complex64, - torch.complex128: np.complex128, - } - return TORCH_TO_NUMPY_DTYPE[torch_dtype] +def convert_dtype_to_numpy(dtype): + return get_backend_func("type_convert")[dtype] def _check_bishengir_able_save_ir() -> bool: - bishengir_path = _get_npucompiler_path() + bishengir_path, _ = _get_npucompiler_path() if not _is_valid_bishengir_path(bishengir_path): print(f"ERROR: Invalid bishengir path format: {bishengir_path}") return False diff --git a/third_party/ascend/examples/autotune_cases/02-fused-softmax.py b/third_party/ascend/examples/autotune_cases/02-fused-softmax.py deleted file mode 100644 index 8c6544d49..000000000 --- a/third_party/ascend/examples/autotune_cases/02-fused-softmax.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
-""" -Fused Softmax -============= -""" - -import os - -import torch -import torch_npu -import triton -import triton.language as tl -from triton.testing import do_bench_npu - - -# split_params={"x": "XBLOCK"}, tiling_params={"x": "XBLOCK_SUB"}, low_dims=["y"] -# persistent_reduction=False, dual_reduction=False -@triton.autotune( - configs=[], - hints={"enable_ascend_autotune": True}, - key=["n_rows", "n_cols"], -) -@triton.jit -def softmax_kernel( - output_ptr, - input_ptr, - input_row_stride, - output_row_stride, - n_rows, - n_cols, - BLOCK_SIZE: tl.constexpr, - XBLOCK: tl.constexpr, - XBLOCK_SUB: tl.constexpr, -): - # starting row of the program - row_start = tl.program_id(0) * XBLOCK - for row_idx in tl.range(0, XBLOCK, XBLOCK_SUB): - # The stride represents how much we need to increase the pointer to advance 1 row - row_offsets = row_start + row_idx + tl.arange(0, XBLOCK_SUB)[:, None] - col_offsets = tl.arange(0, BLOCK_SIZE)[None, :] - xmask = row_offsets < n_rows - ymask = col_offsets < n_cols - mask = xmask & ymask - input_ptrs = input_ptr + (row_offsets * input_row_stride + col_offsets) - # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols - row = tl.load(input_ptrs, mask=mask, other=-float("inf")) - # Subtract maximum for numerical stability - row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(XBLOCK_SUB, BLOCK_SIZE) - # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) - numerator = tl.exp(row_minus_max) - denominator = (tl.sum(numerator, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(XBLOCK_SUB, BLOCK_SIZE)) - softmax_output = numerator / denominator - # Write back output to DRAM - output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets) - tl.store(output_ptrs, softmax_output, mask=mask) - - -def softmax_torch(x): - return torch.softmax(x, axis=-1) - - -def softmax_autotune(x): - n_rows, n_cols = x.shape - BLOCK_SIZE = n_cols - - # Allocate output - y = torch.empty_like(x) - grid = lambda meta: (triton.cdiv(n_rows, meta["XBLOCK"]), 1, 1) - # Create a number of persistent programs. - softmax_kernel[grid](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE) - return y - - -def test_softmax(shape, dtype): - os.environ["TRITON_BENCH_METHOD"] = ( - "npu" # use torch_npu.profiler to get calculating time - ) - x = torch.randn(shape, dtype=dtype, device="npu") - - y_torch = softmax_torch(x) - y_triton = softmax_autotune(x) - assert torch.allclose(y_triton, y_torch) - - time_eager = do_bench_npu(lambda: softmax_torch(x)) - time_triton = do_bench_npu(lambda: softmax_autotune(x)) - assert (time_eager / time_triton) >= 0.8 - print(f"Fused Softmax {shape} {dtype} PASSED!") - - -if __name__ == "__main__": - test_softmax((16896, 1024), torch.float32) diff --git a/third_party/ascend/examples/autotune_cases/03-layer-norm.py b/third_party/ascend/examples/autotune_cases/03-layer-norm.py deleted file mode 100644 index ebdc702f7..000000000 --- a/third_party/ascend/examples/autotune_cases/03-layer-norm.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -""" -Layer Normalization -============= -""" - -import os - -import torch -import torch_npu -import triton -import triton.language as tl -from triton.testing import do_bench_npu - - -# split_params={"x": "XBLOCK_SIZE"}, tiling_params={"y": "RBLOCK_SIZE"}, low_dims=["y"] -# persistent_reduction=False, dual_reduction=False, -@triton.autotune( - configs=[], - hints={"enable_ascend_autotune": True}, - key=["M", "N"], -) -@triton.jit -def _layer_norm_fwd_fused( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride, # how much to increase the pointer when moving by 1 row - N, - M, # number of columns in X - eps, # epsilon to avoid division by zero - XBLOCK_SIZE: tl.constexpr, - RBLOCK_SIZE: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. 
- row_begin = tl.program_id(0) * XBLOCK_SIZE - row_idx = row_begin + tl.arange(0, XBLOCK_SIZE) - row_mask = row_idx < M - row_offsets = row_idx[:, None] * stride - # Compute mean - _mean = tl.zeros((XBLOCK_SIZE, RBLOCK_SIZE), dtype=tl.float32) - for off in range(0, N, RBLOCK_SIZE): - col_idx = off + tl.arange(0, RBLOCK_SIZE) - col_mask = col_idx < N - mask = row_mask[:, None] & col_mask[None, :] - a = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to(tl.float32) - _mean += a - mean = tl.sum(_mean, axis=1, keep_dims=True) / N - # Compute variance - _var = tl.zeros((XBLOCK_SIZE, RBLOCK_SIZE), dtype=tl.float32) - for off in range(0, N, RBLOCK_SIZE): - col_idx = off + tl.arange(0, RBLOCK_SIZE) - col_mask = col_idx < N - mask = row_mask[:, None] & col_mask[None, :] - x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to(tl.float32) - x = tl.where(mask, x - mean, 0.0) - _var += x * x - var = tl.sum(_var, axis=1, keep_dims=True) / N - rstd = 1 / tl.sqrt(var + eps) - # Write mean / rstd - tl.store(Mean + row_idx[:, None], mean, mask=row_mask[:, None]) - tl.store(Rstd + row_idx[:, None], rstd, mask=row_mask[:, None]) - # Normalize and apply linear transformation - for off in range(0, N, RBLOCK_SIZE): - col_idx = off + tl.arange(0, RBLOCK_SIZE) - col_mask = col_idx < N - mask = row_mask[:, None] & col_mask[None, :] - w = tl.load(W + col_idx, mask=col_mask).reshape((1, RBLOCK_SIZE)) - b = tl.load(B + col_idx, mask=col_mask).reshape((1, RBLOCK_SIZE)) - x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.0).to(tl.float32) - x_hat = (x - mean) * rstd - y = x_hat * w + b - # Write output - tl.store(Y + row_offsets + col_idx[None, :], y, mask=mask) - - -def layer_norm_torch(args): - x, w_shape, weight, bias, eps, dtype = args - return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) - - -def layer_norm_autotune(args): - x, weight, bias, eps = args - # allocate output - y = torch.empty_like(x) - # reshape input data into 2D tensor - x_arg = x.reshape(-1, x.shape[-1]) - M, N = x_arg.shape - mean = torch.empty((M, ), dtype=torch.float32, device=x.device) - rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) - - grid = lambda meta: (triton.cdiv(M, meta["XBLOCK_SIZE"]), 1, 1) - # enqueue kernel - _layer_norm_fwd_fused[grid]( # - x_arg, y, weight, bias, mean, rstd, x_arg.stride(0), N, M, eps # - ) - return y - - -def test_layer_norm(shape, dtype, eps=1e-5): - os.environ["TRITON_BENCH_METHOD"] = ( - "npu" # use torch_npu.profiler to get calculating time - ) - M, N = shape - device = "npu" - x_shape = shape - w_shape = (x_shape[-1], ) - weight = torch.rand(w_shape, dtype=dtype, device=device) - bias = torch.rand(w_shape, dtype=dtype, device=device) - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) - y_torch = layer_norm_torch((x, w_shape, weight, bias, eps, dtype)) - y_triton = layer_norm_autotune((x, weight, bias, eps)) - assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0) - print(f"Layer Normalization {M},{N} {dtype} PASSED!") - - -if __name__ == "__main__": - test_layer_norm((128, 128), torch.float16) diff --git a/third_party/ascend/examples/benchmark_cases/layernorm_perf.py b/third_party/ascend/examples/benchmark_cases/layernorm_perf.py deleted file mode 100644 index ed100aa38..000000000 --- a/third_party/ascend/examples/benchmark_cases/layernorm_perf.py +++ /dev/null @@ -1,469 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -""" -Layer Normalization -==================== -In this tutorial, you will write a high-performance layer normalization -kernel that runs faster than the PyTorch implementation. - -In doing so, you will learn about: - -* Implementing backward pass in Triton. - -* Implementing parallel reduction in Triton. - -""" - -# %% -# Motivations -# ----------- -# -# The *LayerNorm* operator was first introduced in [BA2016]_ as a way to improve the performance -# of sequential models (e.g., Transformers) or neural networks with small batch size. -# It takes a vector :math:`x` as input and produces a vector :math:`y` of the same shape as output. -# The normalization is performed by subtracting the mean and dividing by the standard deviation of :math:`x`. -# After the normalization, a learnable linear transformation with weights :math:`w` and biases :math:`b` is applied. -# The forward pass can be expressed as follows: -# -# .. math:: -# y = \frac{ x - \text{E}[x] }{ \sqrt{\text{Var}(x) + \epsilon} } * w + b -# -# where :math:`\epsilon` is a small constant added to the denominator for numerical stability. -# Let’s first take a look at the forward pass implementation. - -import torch -import torch_npu - -import triton -import triton.language as tl - -import time - -HAS_APEX = False -DEVICE = "npu" - - -@triton.jit -def _layer_norm_fwd_fused(X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride, # how much to increase the pointer when moving by 1 row - N, M, # number of columns in X - eps, # epsilon to avoid division by zero - XBLOCK_SIZE: tl.constexpr, RBLOCK_SIZE: tl.constexpr): - # Map the program id to the row of X and Y it should compute. 
- row_begin = tl.program_id(0) * RBLOCK_SIZE - row_idx = row_begin + tl.arange(0, RBLOCK_SIZE) - row_mask = row_idx < M - row_offsets = row_idx[:, None] * stride - # Compute mean - - _mean = tl.zeros((RBLOCK_SIZE, XBLOCK_SIZE), dtype=tl.float32) - for off in range(0, N, XBLOCK_SIZE): - col_idx = off + tl.arange(0, XBLOCK_SIZE) - col_mask = col_idx < N - mask = row_mask[:, None] & col_mask[None, :] - a = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.).to(tl.float32) - _mean += a - mean = tl.sum(_mean, axis=1, keep_dims=True) / N - - # Compute variance - _var = tl.zeros((RBLOCK_SIZE, XBLOCK_SIZE), dtype=tl.float32) - for off in range(0, N, XBLOCK_SIZE): - col_idx = off + tl.arange(0, XBLOCK_SIZE) - col_mask = col_idx < N - mask = row_mask[:, None] & col_mask[None, :] - x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.).to(tl.float32) - x = tl.where(mask, x - mean, 0.) - _var += x * x - var = tl.sum(_var, axis=1, keep_dims=True) / N - - rstd = 1 / tl.sqrt(var + eps) - - # Write mean / rstd - tl.store(Mean + row_idx[:, None], mean, mask=row_mask[:, None]) - tl.store(Rstd + row_idx[:, None], rstd, mask=row_mask[:, None]) - # mean = mean.broadcast_to((RBLOCK_SIZE, XBLOCK_SIZE)) - # rstd = rstd.broadcast_to((RBLOCK_SIZE, XBLOCK_SIZE)) - # Normalize and apply linear transformation - for off in range(0, N, XBLOCK_SIZE): - col_idx = off + tl.arange(0, XBLOCK_SIZE) - col_mask = col_idx < N - mask = row_mask[:, None] & col_mask[None, :] - w = tl.load(W + col_idx, mask=col_mask).reshape((1, XBLOCK_SIZE)) - b = tl.load(B + col_idx, mask=col_mask).reshape((1, XBLOCK_SIZE)) - x = tl.load(X + row_offsets + col_idx[None, :], mask=mask, other=0.).to(tl.float32) - x_hat = (x - mean) * rstd - y = x_hat * w + b - # Write output - tl.store(Y + row_offsets + col_idx[None, :], y, mask=mask) - - -# %% -# Backward pass -# ------------- -# -# The backward pass for the layer normalization operator is a bit more involved than the forward pass. -# Let :math:`\hat{x}` be the normalized inputs :math:`\frac{ x - \text{E}[x] }{ \sqrt{\text{Var}(x) + \epsilon} }` before the linear transformation, -# the Vector-Jacobian Products (VJP) :math:`\nabla_{x}` of :math:`x` are given by: -# -# .. math:: -# \nabla_{x} = \frac{1}{\sigma}\Big( \nabla_{y} \odot w - \underbrace{ \big( \frac{1}{N} \hat{x} \cdot (\nabla_{y} \odot w) \big) }_{c_1} \odot \hat{x} - \underbrace{ \frac{1}{N} \nabla_{y} \cdot w }_{c_2} \Big) -# -# where :math:`\odot` denotes the element-wise multiplication, :math:`\cdot` denotes the dot product, and :math:`\sigma` is the standard deviation. -# :math:`c_1` and :math:`c_2` are intermediate constants that improve the readability of the following implementation. -# -# For the weights :math:`w` and biases :math:`b`, the VJPs :math:`\nabla_{w}` and :math:`\nabla_{b}` are more straightforward: -# -# .. math:: -# \nabla_{w} = \nabla_{y} \odot \hat{x} \quad \text{and} \quad \nabla_{b} = \nabla_{y} -# -# Since the same weights :math:`w` and biases :math:`b` are used for all rows in the same batch, their gradients need to sum up. -# To perform this step efficiently, we use a parallel reduction strategy: each kernel instance accumulates -# partial :math:`\nabla_{w}` and :math:`\nabla_{b}` across certain rows into one of :math:`\text{GROUP_SIZE_M}` independent buffers. -# These buffers stay in the L2 cache and then are further reduced by another function to compute the actual :math:`\nabla_{w}` and :math:`\nabla_{b}`. 
-# -# Let the number of input rows :math:`M = 4` and :math:`\text{GROUP_SIZE_M} = 2`, -# here's a diagram of the parallel reduction strategy for :math:`\nabla_{w}` (:math:`\nabla_{b}` is omitted for brevity): -# -# .. image:: parallel_reduction.png -# -# In Stage 1, the rows of X that have the same color share the same buffer and thus a lock is used to ensure that only one kernel instance writes to the buffer at a time. -# In Stage 2, the buffers are further reduced to compute the final :math:`\nabla_{w}` and :math:`\nabla_{b}`. -# In the following implementation, Stage 1 is implemented by the function :code:`_layer_norm_bwd_dx_fused` and Stage 2 is implemented by the function :code:`_layer_norm_bwd_dwdb`. - - -@triton.jit -def _layer_norm_bwd_dx_fused(DX, # pointer to the input gradient - DY, # pointer to the output gradient - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - X, # pointer to the input - W, # pointer to the weights - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - Lock, # pointer to the lock - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): - # Map the program id to the elements of X, DX, and DY it should compute. - row = tl.program_id(0) - cols = tl.arange(0, BLOCK_SIZE_N) - mask = cols < N - X += row * stride - DY += row * stride - DX += row * stride - # Offset locks and weights/biases gradient pointer for parallel reduction - lock_id = row % GROUP_SIZE_M - Lock += lock_id - Count = Lock + GROUP_SIZE_M - DW = DW + lock_id * N + cols - DB = DB + lock_id * N + cols - # Load data to SRAM - x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) - dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) - w = tl.load(W + cols, mask=mask).to(tl.float32) - mean = tl.load(Mean + row) - rstd = tl.load(Rstd + row) - # Compute dx - xhat = (x - mean) * rstd - wdy = w * dy - xhat = tl.where(mask, xhat, 0.) - wdy = tl.where(mask, wdy, 0.) - c1 = tl.sum(xhat * wdy, axis=0) / N - c2 = tl.sum(wdy, axis=0) / N - dx = (wdy - (xhat * c1 + c2)) * rstd - # Write dx - tl.store(DX + cols, dx, mask=mask) - # Accumulate partial sums for dw/db - partial_dw = (dy * xhat).to(w.dtype) - partial_db = (dy).to(w.dtype) - while tl.atomic_cas(Lock, 0, 1) == 1: - pass - count = tl.load(Count) - # First store doesn't accumulate - if count == 0: - tl.atomic_xchg(Count, 1) - else: - partial_dw += tl.load(DW, mask=mask) - partial_db += tl.load(DB, mask=mask) - tl.store(DW, partial_dw, mask=mask) - tl.store(DB, partial_db, mask=mask) - # Release the lock - tl.atomic_xchg(Lock, 0) - - -@triton.jit -def _layer_norm_bwd_dwdb(DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - FINAL_DW, # pointer to the weights gradient - FINAL_DB, # pointer to the biases gradient - M, # GROUP_SIZE_M - N, # number of columns - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): - # Map the program id to the elements of DW and DB it should compute. - pid = tl.program_id(0) - cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Iterate through the rows of DW and DB to sum the partial sums. 
- for i in range(0, M, BLOCK_SIZE_M): - rows = i + tl.arange(0, BLOCK_SIZE_M) - mask = (rows[:, None] < M) & (cols[None, :] < N) - offs = rows[:, None] * N + cols[None, :] - dw += tl.load(DW + offs, mask=mask, other=0.) - db += tl.load(DB + offs, mask=mask, other=0.) - # Write the final sum to the output. - sum_dw = tl.sum(dw, axis=0) - sum_db = tl.sum(db, axis=0) - tl.store(FINAL_DW + cols, sum_dw, mask=cols < N) - tl.store(FINAL_DB + cols, sum_db, mask=cols < N) - - -# %% -# Benchmark -# --------- -# -# We can now compare the performance of our kernel against that of PyTorch. -# Here we focus on inputs that have Less than 64KB per feature. -# Specifically, one can set :code:`'mode': 'backward'` to benchmark the backward pass. - -device = torch.npu.current_device() -stream = torch.npu.current_stream(device).npu_stream -kernels = {} - - -class LayerNorm(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, normalized_shape, weight, bias, eps): - - # allocate output - y = torch.empty_like(x) - # reshape input data into 2D tensor - x_arg = x.reshape(-1, x.shape[-1]) - M, N = x_arg.shape - mean = torch.empty((M, ), dtype=torch.float32, device=x.device) - rstd = torch.empty((M, ), dtype=torch.float32, device=x.device) - # Less than 64KB per feature: enqueue fused kernel - - MAX_FUSED_SIZE = 65536 // x.element_size() - XBLOCK_SIZE = 256 - RBLOCK_SIZE = 32 - NUM_CORE = (M - 1) // RBLOCK_SIZE + 1 - num_warps = min(max((N - 1) // XBLOCK_SIZE + 1, 1), 8) - # enqueue kernel - - kernel, num_programs = kernels.get(XBLOCK_SIZE ^ RBLOCK_SIZE, (None, NUM_CORE)) - if kernel is None: - kernel = _layer_norm_fwd_fused.warmup(x_arg, y, weight, bias, mean, rstd, # - x_arg.stride(0), N, M, eps, # - XBLOCK_SIZE=XBLOCK_SIZE, RBLOCK_SIZE=RBLOCK_SIZE, grid=(NUM_CORE, )) - kernel._init_handles() - kernels[XBLOCK_SIZE ^ RBLOCK_SIZE] = (kernel, num_programs) - - kernel[(num_programs, 1, 1)]( # - x_arg, - y, - weight, - bias, - mean, - rstd, # - x_arg.stride(0), - N, - M, - eps, # - stream=stream, - ) - - # _layer_norm_fwd_fused[(NUM_CORE, )]( # - # x_arg, y, weight, bias, mean, rstd, # - # x_arg.stride(0), N, M, eps, # - # XBLOCK_SIZE = XBLOCK_SIZE, - # RBLOCK_SIZE = RBLOCK_SIZE, - # num_warps=num_warps, - # num_ctas=1) - ctx.save_for_backward(x, weight, bias, mean, rstd) - # ctx.BLOCK_SIZE = XBLOCK_SIZE - ctx.num_warps = num_warps - ctx.eps = eps - return y - - @staticmethod - def backward(ctx, dy): - x, w, b, m, v = ctx.saved_tensors - # heuristics for amount of parallel reduction stream for DW/DB - N = w.shape[0] - GROUP_SIZE_M = 64 - if N <= 8192: GROUP_SIZE_M = 96 - if N <= 4096: GROUP_SIZE_M = 128 - if N <= 1024: GROUP_SIZE_M = 256 - # allocate output - locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device) - _dw = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device) - _db = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device) - dw = torch.empty((N, ), dtype=w.dtype, device=w.device) - db = torch.empty((N, ), dtype=w.dtype, device=w.device) - dx = torch.empty_like(dy) - # enqueue kernel using forward pass heuristics - # also compute partial sums for DW and DB - x_arg = x.reshape(-1, x.shape[-1]) - M, N = x_arg.shape - _layer_norm_bwd_dx_fused[(M, )]( # - dx, dy, _dw, _db, x, w, m, v, locks, # - x_arg.stride(0), N, # - BLOCK_SIZE_N=ctx.BLOCK_SIZE, # - GROUP_SIZE_M=GROUP_SIZE_M, # - num_warps=ctx.num_warps) - grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] - # accumulate partial sums in separate kernel - _layer_norm_bwd_dwdb[grid]( - _dw, _db, dw, db, 
min(GROUP_SIZE_M, M), N, # - BLOCK_SIZE_M=32, # - BLOCK_SIZE_N=128, num_ctas=1) - return dx, None, dw, db, None - - -layer_norm = LayerNorm.apply - - -def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE): - # create data - x_shape = (M, N) - w_shape = (x_shape[-1], ) - weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) - bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) - dy = .1 * torch.randn_like(x) - x.requires_grad_(True) - # forward pass - y_tri = layer_norm(x, w_shape, weight, bias, eps) - y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype) - - assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) - - -@triton.testing.perf_report( - triton.testing.Benchmark(x_names=['N'], x_vals=[512 * i for i in range(20, 30)], line_arg='provider', - line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []), - line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []), styles=[('blue', '-'), - ('green', '-'), - ('orange', '-')], - ylabel='GB/s', plot_name='layer-norm-backward', - args={'M': 3072, 'dtype': torch.float16, 'mode': 'forward'}, # 4096 better - )) -def bench_layer_norm(M, N, dtype, provider, mode='forward', eps=1e-5, device=DEVICE): - # create data - x_shape = (M, N) - w_shape = (x_shape[-1], ) - weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) - bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True) - x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) - dy = .1 * torch.randn_like(x) - x.requires_grad_(True) - quantiles = [0.5, 0.2, 0.8] - - def y_fwd(): - - if provider == "triton": - return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 - - if provider == "torch": - return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 - - if provider == "apex": - apex_layer_norm = (apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype)) - return apex_layer_norm(x) # noqa: F811, E704 - - # forward pass - if mode == 'forward': - gbps = lambda ms: ms * 1000 - ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) - # backward pass - if mode == 'backward': - y = y_fwd() - gbps = lambda ms: ms * 1000 - ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles, - grad_to_none=[x], rep=500) - return gbps(ms), gbps(max_ms), gbps(min_ms) - - -def benchmark_test(fn, fn_triton, args=(), name="gen_fn", times=100, repeat=10): - print(f"--------------------benchmark_{name} for {times * repeat} times--------------------") - stream = torch.npu.current_stream() - # warm_up - stream.synchronize() - for _ in range(10): - fn_triton(*args) - stream.synchronize() - - start = time.perf_counter() - for _ in range(times * repeat): - fn_triton(*args) - stream.synchronize() - end = time.perf_counter() - - time_compiled = (end - start) / (times * repeat) - time_compiled *= 1000000 - print(f"time_triton:{time_compiled:.6f}") - - print(f"Runing eager {name} for {times * repeat} times") - - # warm_up - stream.synchronize() - for _ in range(10): - std = fn(*args) - stream.synchronize() - - start = time.perf_counter() - for _ in range(times * repeat): - std = fn(*args) - stream.synchronize() - end = time.perf_counter() - time_eager = (end - start) / (times * repeat) - time_eager *= 1000000 - print(f"time_eager:{time_eager:.6f}") - - accelerated = (time_eager - time_compiled) / 
time_compiled * 100 - print(f"Accelerated: {accelerated:.4f}% eager takes {time_eager:.3f} us, triton takes {time_compiled:.3f} us") - - return accelerated, time_eager, time_compiled - - -test_layer_norm(1151, 8192, torch.float16) - -M = 2048 -N = 8192 # 12288 12800 13312 13000 -x_shape = (M, N) -w_shape = (x_shape[-1], ) -weight = torch.rand(w_shape, dtype=torch.float16, device='npu', requires_grad=True) -bias = torch.rand(w_shape, dtype=torch.float16, device='npu', requires_grad=True) -x = -2.3 + 0.5 * torch.randn(x_shape, dtype=torch.float16, device='npu') -eps = 1e-5 -benchmark_test(torch.nn.functional.layer_norm, layer_norm, args=(x, w_shape, weight, bias, eps)) - -# %% -# References -# ---------- -# -# .. [BA2016] Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton, "Layer Normalization", Arxiv 2016 diff --git a/third_party/ascend/examples/benchmark_cases/softmax_perf.py b/third_party/ascend/examples/benchmark_cases/softmax_perf.py deleted file mode 100644 index 2d5daffe3..000000000 --- a/third_party/ascend/examples/benchmark_cases/softmax_perf.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -""" -Fused Softmax -============= - -In this tutorial, you will write a fused softmax operation that is significantly faster -than PyTorch's native op for a particular class of matrices: those whose rows can fit in -the NPU's SRAM. - -In doing so, you will learn about: - -* The benefits of kernel fusion for bandwidth-bound operations. - -* Reduction operators in Triton. - -""" - -# %% -# Motivations -# ----------- -# -# Custom NPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. -# Let us consider instead the case of a simple (numerically stabilized) softmax operation: - -import torch -import torch_npu -import triton -import triton.language as tl -from triton.runtime import driver -import time - - -def naive_softmax(x): - """Compute row-wise softmax of X using native pytorch - - We subtract the maximum element in order to avoid overflows. Softmax is invariant to - this shift. 
- """ - # read MN elements ; write M elements - x_max = x.max(dim=1)[0] - # read MN + M elements ; write MN elements - z = x - x_max[:, None] - # read MN elements ; write MN elements - numerator = torch.exp(z) - # read MN elements ; write M elements - denominator = numerator.sum(dim=1) - # read MN + M elements ; write MN elements - ret = numerator / denominator[:, None] - # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements - return ret - - -# %% -# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` -# requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements. -# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads -# X once and does all the necessary computations on-chip. -# Doing so would require reading and writing back only :math:`MN` bytes, so we could -# expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`). -# The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically -# but, as we will see later, it is still far from ideal. - -# %% -# Compute Kernel -# -------------- -# -# Our softmax kernel works as follows: each program loads a set of rows of the input matrix X strided by number of programs, -# normalizes it and writes back the result to the output Y. -# -# Note that one important limitation of Triton is that each block must have a -# power-of-two number of elements, so we need to internally "pad" each row and guard the -# memory operations properly if we want to handle any possible input shapes: - - -@triton.jit -def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, - XBLOCK: tl.constexpr, num_stages: tl.constexpr): - # starting row of the program - row_start = tl.program_id(0) * XBLOCK - XBLOCK_SUB: tl.constexpr = 8 - #for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages): - for row_idx in tl.range(0, XBLOCK, XBLOCK_SUB): - # The stride represents how much we need to increase the pointer to advance 1 row - row_offsets = row_start + row_idx + tl.arange(0, XBLOCK_SUB)[:, None] - col_offsets = tl.arange(0, BLOCK_SIZE)[None, :] - xmask = (row_offsets < n_rows) - ymask = (col_offsets < n_cols) - mask = xmask & ymask - input_ptrs = input_ptr + (row_offsets * input_row_stride + col_offsets) - # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols - row = tl.load(input_ptrs, mask=mask, other=-float('inf')) - # Subtract maximum for numerical stability - row_minus_max = row - tl.max(row, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(XBLOCK_SUB, BLOCK_SIZE) - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=1).reshape(XBLOCK_SUB, 1).broadcast_to(XBLOCK_SUB, BLOCK_SIZE) - softmax_output = numerator / denominator - # Write back output to DRAM - output_ptrs = output_ptr + (row_offsets * output_row_stride + col_offsets) - tl.store(output_ptrs, softmax_output, mask=mask) - - -# %% -# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor. 
-# NUM_SM = properties["multiprocessor_count"] -# NUM_REGS = properties["max_num_regs"] -# SIZE_SMEM = properties["max_shared_mem"] -# WARP_SIZE = properties["warpSize"] -target = triton.runtime.driver.active.get_current_target() -kernels = {} - -device = torch.npu.current_device() -stream = torch.npu.current_stream(device).npu_stream - - -def softmax(x): - n_rows, n_cols = x.shape - - # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x` - num_programs = 32 - - XBLOCK = (n_rows + num_programs - 1) // num_programs - BLOCK_SIZE = n_cols - # Another trick we can use is to ask the compiler to use more threads per row by - # increasing the number of warps (`num_warps`) over which each row is distributed. - # You will see in the next tutorial how to auto-tune this value in a more natural - # way so you don't have to come up with manual heuristics yourself. - num_warps = 8 - - # Number of software piepling stages. - #num_stages = 4 if SIZE_SMEM > 200000 else 2 - num_stages = 4 - - # Allocate output - y = torch.empty_like(x) - - # pre-compile kernel to get register usage and compute thread occupancy. - kernel, num_programs = kernels.get(BLOCK_SIZE, (None, num_programs)) - if kernel is None: - kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, - XBLOCK=XBLOCK, num_stages=num_stages, num_warps=num_warps, grid=(32, )) - kernel._init_handles() - # n_regs = kernel.n_regs - # size_smem = kernel.metadata.shared - # occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps) - # occupancy = min(occupancy, SIZE_SMEM // size_smem) - # num_programs = NUM_SM * occupancy - kernels[BLOCK_SIZE] = (kernel, num_programs) - - num_programs = min(num_programs, n_rows) - - # Create a number of persistent programs. - kernel[(32, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, stream=stream) - return y - - -# %% -# Unit Test -# --------- - -# %% -# We make sure that we test our kernel on a matrix with an irregular number of rows and columns. -# This will allow us to verify that our padding mechanism works. - - -def torch_softmax(x): - return torch.softmax(x, axis=-1) - - -torch.manual_seed(0) -# x = torch.randn(1823, 781, device='npu') -x = torch.randn(4096, 1024, device='npu') -y_triton = softmax(x) -y_torch = torch_softmax(x) -assert torch.allclose(y_triton, y_torch), (y_triton, y_torch) - - -#torch.testing.assert_close(y_triton, y_torch, rtol=1e-3, atol=1e-3) -# %% -# As expected, the results are identical. 
-def benchmark_test(fn, fn_triton, args=(), name="gen_fn", times=100, repeat=10): - print(f"--------------------benchmark_{name} for {times * repeat} times--------------------") - stream = torch.npu.current_stream() - # warm_up - stream.synchronize() - for _ in range(10): - fn_triton(args) - stream.synchronize() - - start = time.perf_counter() - for _ in range(times * repeat): - fn_triton(args) - stream.synchronize() - end = time.perf_counter() - - time_compiled = (end - start) / (times * repeat) - time_compiled *= 1000000 - print(f"time_triton:{time_compiled:.6f}") - - print(f"Runing eager {name} for {times * repeat} times") - - # warm_up - stream.synchronize() - for _ in range(10): - std = fn(args) - stream.synchronize() - - start = time.perf_counter() - for _ in range(times * repeat): - std = fn(args) - stream.synchronize() - end = time.perf_counter() - time_eager = (end - start) / (times * repeat) - time_eager *= 1000000 - print(f"time_eager:{time_eager:.6f}") - - accelerated = (time_eager - time_compiled) / time_compiled * 100 - print(f"Accelerated: {accelerated:.4f}% eager takes {time_eager:.3f} us, triton takes {time_compiled:.3f} us") - - return accelerated, time_eager, time_compiled - - -# x = torch.randn(4096, 1024, device='npu') -benchmark_test(torch_softmax, softmax, args=x) -# %% -# Benchmark -# --------- -# -# Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows. -# We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above. - -# @triton.testing.perf_report( -# triton.testing.Benchmark( -# x_names=['N'], # argument names to use as an x-axis for the plot -# x_vals=[128 * i for i in range(2, 8)], # different possible values for `x_name` -# line_arg='provider', # argument name whose value corresponds to a different line in the plot -# line_vals=['triton', 'torch'], # possible values for `line_arg`` -# line_names=[ -# "Triton", -# "Torch", -# ], # label name for the lines -# styles=[('blue', '-'), ('green', '-')], # line styles -# ylabel="GB/s", # label name for the y-axis -# plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. -# args={'M': 4096}, # values for function arguments not in `x_names` and `y_name` -# )) -# def benchmark(M, N, provider): -# x = torch.randn(M, N, device='npu', dtype=torch.float32) -# #stream = torch.npu.Stream() -# #torch.npu.set_stream(stream) - -# if provider == 'torch': -# ms = triton.testing.do_bench(lambda: torch_softmax(x)) -# if provider == 'triton': -# ms = triton.testing.do_bench(lambda: softmax(x)) -# # gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) -# gbps = lambda ms: ms*1000 -# return gbps(ms) - -# benchmark.run(show_plots=True, print_data=True) - -# %% -# In the above plot, we can see that: -# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here. -# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**. -# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape. 
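The deleted benchmark cases above used triton.testing.perf_report for their sweep plots; that decorator, together with the Benchmark class documented earlier in this patch, remains the standard way to express such sweeps. A minimal sketch; the sizes, provider names, and the add() wrapper are illustrative, not part of this patch:

    import torch
    import triton

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],                            # x axis: problem size
            x_vals=[2**i for i in range(12, 20)],
            line_arg='provider',                      # one plotted line per provider
            line_vals=['torch', 'triton'],
            line_names=['Torch', 'Triton'],
            styles=[('green', '-'), ('blue', '-')],
            ylabel='ms',
            plot_name='vector-add-performance',
            args={},
        ))
    def benchmark(N, provider):
        x = torch.rand(N, device='npu')
        y = torch.rand(N, device='npu')
        fn = (lambda: x + y) if provider == 'torch' else (lambda: add(x, y))  # add() is a hypothetical Triton wrapper
        return triton.testing.do_bench(fn)

    benchmark.run(print_data=True, show_plots=False)

benchmark.run() then produces the per-provider DataFrame (and, with save_path set, the results.html and CSV files) handled by Mark._run and Mark.run earlier in this patch.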
diff --git a/third_party/ascend/examples/custom_op/builtin_ops_demo.py b/third_party/ascend/examples/custom_op/builtin_ops_demo.py new file mode 100755 index 000000000..e3e9fc675 --- /dev/null +++ b/third_party/ascend/examples/custom_op/builtin_ops_demo.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +import subprocess +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg + + +@triton.jit +def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + index = tl.full([8], 0, tl.int32) + value = tl.full([8, 64], 0, tl.float32) + tmp = tl.full([8], 0, tl.float32) + x = al.custom("__builtin_index_select", x_ptr, index, dim=0, bound=100, end_offset=(2, 2), start_offset=(0, 0), + src_stride=(4, 1), out=x) + al.custom("__builtin_index_put", x_ptr, index, value, dim=0, bound=12, dst_shape=(1, 2, 3), dst_offset=(4, 5, 6), + dst_stride=(8, 4, 1)) + tmp = al.custom("__builtin_gather_load", y_ptr, index, bound=100, dim=0, src_stride=(1, ), index_shape=(3, ), + offsets=(0, ), out=tmp) + al.custom("__builtin_scatter_store", out_ptr, value, index, 1, 0, (1, ), (2, ), (1, )) + y = al.custom("__builtin_indirect_load", x_ptr, index, mask=i < n, other=y, out=y) + al.custom("__builtin_indirect_store", out_ptr, index, value) + tl.store(out_ptr + i, y, mask=i < n) + + +if __name__ == "__main__": + src = ASTSource(my_kernel, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, {"BLOCK": 256}) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + options = NPUOptions() + try: + ttir = ast_to_ttir(my_kernel, src, context, options, {}, {}) + print("=== TTIR ===") + print(ttir) + metadata = { + **options.__dict__, + } + linalg = ttir_to_linalg(ttir, metadata, options, named_ops=True) + print("=== MLIR (linalg) ===") + print(linalg) + except subprocess.CalledProcessError as ex: + print(ex.stdout.decode()) + print(ex.stderr.decode()) + print("failed") diff --git a/third_party/ascend/examples/custom_op/custom_op_demo.py b/third_party/ascend/examples/custom_op/custom_op_demo.py new file mode 100755 index 000000000..b14ce4f59 --- /dev/null +++ b/third_party/ascend/examples/custom_op/custom_op_demo.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +import subprocess +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg + + +@al.register_custom_op +class min_custom_op: + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_MTE2 + mode = al.MODE.SIMD + + +@al.register_custom_op +class simple_custom_op: + # name is optional, use class name by default. + name = 'simple_custom_op' + + # required attributes. + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + + # __init__ method is optional, but it can be used for better user experience + # when provided. for example, you can validate arguments here. 
+ def __init__(self, x, y, dim=0, out=None): + assert x.shape == y.shape, "x and y should have same shape" + assert isinstance(dim, int), "dim should be const integer" + assert out, "out is required" + + +@al.register_custom_op +class _example_custom_op: + name = 'example_custom_op' + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + + def __init__(self, src, index, offset: tl.int64, axis, out=None): + # support validate arguments in __init__ method. + assert isinstance(src, tl.tensor), "src should be tensor" + assert index.dtype.is_int(), "index should be integer tensor" + assert isinstance(offset, int), "offset should be integer" + assert isinstance(axis, int), "axis should be integer" + + # support multi-output by using tuple or list. + assert isinstance(out, tuple) and len(out) == 2, "out should be tuple of 2 items" + + # setup the symbol name of the function that will be called at runtime. + rank = len(index.shape) + self.symbol = f"{self.name}_{rank}d_{src.dtype.cname}_{index.dtype.cname}" + + # setup source and compile command if it is implemented by user source code. + self.source = f"workspace/example_custom_op_impl.cce" + self.compile = "bisheng -O2 -std=c++17 -o $@ -c $<" + + # dynamic set argument type. + self.arg_type['axis'] = index.dtype + + +@al.builtin +def example_op(src, index, offset, axis, _builder=None): + # you can wrap a custom op as a builtin operation, + # output can be provided here to make it easy to use. + x = tl.semantic.full(src.shape, 0, tl.float32, _builder) + y = tl.semantic.full(index.shape, 0, tl.float32, _builder) + return al.custom_semantic(_example_custom_op.name, src, index, offset, axis, out=(x, y), _builder=_builder) + + +@triton.jit +def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + y = al.custom("min_custom_op", x, x_ptr, y_ptr + i, al.int64(0), (1, 2, 3), [4.1, 5.2], out=y) + y = al.custom("simple_custom_op", x, y, dim=1, out=y) + index = tl.full((2, 3), 0, tl.int64) + x, y = al.custom("example_custom_op", x, index, offset=1, axis=0, out=(x, y)) + result, _ = example_op(x, index, offset=2, axis=1) + tl.store(out_ptr + i, result, mask=i < n) + + +if __name__ == "__main__": + src = ASTSource(my_kernel, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, {"BLOCK": 256}) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + options = NPUOptions() + try: + ttir = ast_to_ttir(my_kernel, src, context, options, {}, {}) + print("=== TTIR ===") + print(ttir) + metadata = { + **options.__dict__, + } + linalg = ttir_to_linalg(ttir, metadata, options, named_ops=True) + print("=== MLIR (linalg) ===") + print(linalg) + except subprocess.CalledProcessError as ex: + print(ex.stdout.decode()) + print(ex.stderr.decode()) + print("failed") diff --git a/third_party/ascend/examples/custom_op/test_gather_load.py b/third_party/ascend/examples/custom_op/test_gather_load.py new file mode 100755 index 000000000..03e83b917 --- /dev/null +++ b/third_party/ascend/examples/custom_op/test_gather_load.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +import torch +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al + + +@triton.jit +def test_gather_load_kernel(src_ptr, index_ptr, out_ptr): + # index tile shape: (2, 2) + cols = tl.arange(0, 2)[None, :] # [[0, 1]] + rows = tl.arange(0, 2)[:, None] # [[0],[1]] + mask = (rows < 2) & 
(cols < 2) + + # load index tile to UB + index = tl.load(index_ptr + rows * 2 + cols, mask) + + # gather load from GM to UB + dst = tl.full(index.shape, 0, tl.float32) + gathered = al.custom("__builtin_gather_load", src_ptr, index, bound=4, dim=0, src_stride=(2, 1), index_shape=(2, 2), + offsets=(0, 0), out=dst) + + # store result to GM + tl.store(out_ptr + rows * 2 + cols, gathered, mask) + + +if __name__ == "__main__": + src = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]], device='npu') + index = torch.tensor([[0, 1], [2, 3]], device='npu') + out = torch.empty((2, 2), device='npu', dtype=torch.float32) + test_gather_load_kernel[(1, )](src, index, out) + print("result: ", out) # [[1., 4.], [5., 8.]] diff --git a/third_party/ascend/examples/custom_op/test_index_select.py b/third_party/ascend/examples/custom_op/test_index_select.py new file mode 100644 index 000000000..d06174fde --- /dev/null +++ b/third_party/ascend/examples/custom_op/test_index_select.py @@ -0,0 +1,44 @@ +import pytest +import torch +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al + + +@triton.jit +def builtin_index_select_kernel(src_ptr, index_ptr, out_ptr): + # Define 2x2 tile indices for output tensor + r = tl.arange(0, 2)[:, None] # Row indices: shape [2, 1] + c = tl.arange(0, 2)[None, :] # Column indices: shape [1, 2] + + # Load index tensor (shape [2]) from GM to UB + idx = tl.load(index_ptr + tl.arange(0, 2)) + # Initialize empty 2x2 output tile in UB (default value: 0) + dst = tl.full((2, 2), 0, dtype=tl.float32) + + # Invoke __builtin_index_select custom op to gather elements + out_tile = al.custom("__builtin_index_select", src_ptr, # Pointer to source tensor in GM + idx, # Index tensor (in UB) for gathering + dim=0, # Dimension to gather along + bound=4, # Upper bound for valid index values (out-of-bound check) + end_offset=(2, 2), # End offsets of each dimension for the index tensor + start_offset=(0, 0), # Start offsets of each dimension for the source tensor + src_stride=(4, 1), # Stride of each dimension for the source tensor in GM + out=dst # Output tensor (in UB) to store gathered elements + ) + + # Store the gathered tile from UB to output tensor in GM + tl.store(out_ptr + r * 2 + c, out_tile) + + +if __name__ == "__main__": + src = torch.tensor( + [[10., 11., 12., 13.], [20., 21., 22., 23.], [30., 31., 32., 33.], [40., 41., 42., 43.]], + device="npu", + dtype=torch.float32, + ) + index = torch.tensor([2, 0], device="npu", dtype=torch.int32) + out = torch.empty((2, 2), device="npu", dtype=torch.float32) + ref = torch.index_select(src, 0, index.to(torch.int64))[:, :2] + builtin_index_select_kernel[(1, )](src, index, out) + torch.testing.assert_close(out, ref) # ref: [[30., 31.], [10., 11.]] diff --git a/third_party/ascend/examples/flaggems_cases/op_test_run.sh b/third_party/ascend/examples/flaggems_cases/op_test_run.sh deleted file mode 100755 index 07d91827b..000000000 --- a/third_party/ascend/examples/flaggems_cases/op_test_run.sh +++ /dev/null @@ -1,322 +0,0 @@ -#!/bin/bash - -# 获取传入参数 -device_count="${1:-1}" # 默认使用1个设备 -threads_per_device="${2:-64}" # 每个设备线程数,默认64 - -# 定义路径 -DIR_TESTS="tests" -DIR_BENCHMARK="benchmark" -PR_LOG_DIR="/home/pr_test_log" -TIMESTAMP=$(date +"%Y%m%d") -LOG_ARCHIVE="test_flaggems_logs_${TIMESTAMP}.tar.gz" -SUMMARY_FILE="${WORKSPACE}/triton-ascend/ascend/examples/summary.txt" # 新增:统计信息文件 - -# 检查日志目录 -mkdir -p "$PR_LOG_DIR" || { echo "无法创建日志目录 $PR_LOG_DIR"; exit 1; } - -# 中央计数器文件定义 -COUNTER_FILE=$(mktemp) 
-LOCK_FILE="/tmp/op_test_run.lock" -touch $LOCK_FILE - -# ===== 修改:改进的统计结果收集机制 ===== -# 使用文件存储统计结果 -STATS_DIR=$(mktemp -d) -# 初始化设备统计文件 -for ((device_id=0; device_id < device_count; device_id++)); do - stats_file="${STATS_DIR}/device_${device_id}.stats" - echo "success=0" > "$stats_file" - echo "failure=0" >> "$stats_file" - echo "skipped=0" >> "$stats_file" - echo "error=0" >> "$stats_file" -done - -# 原子更新统计 -record_stats() { - local device_id=$1 - local status=$2 # success/failure/skipped/error - local stats_file="${STATS_DIR}/device_${device_id}.stats" - - ( - flock -x 20 - # 读取当前值 - current=$(grep "^${status}=" "$stats_file" | cut -d= -f2) - # 更新值 - new_value=$((current + 1)) - # 替换文件中的值 - sed -i "s/^${status}=.*/${status}=${new_value}/" "$stats_file" - ) 20>"${stats_file}.lock" -} - -# 任务队列管理函数 -init_task_queue() { - local -n arr_ref=$1 - TASK_FILE=$(mktemp) - printf "%s\n" "${arr_ref[@]}" > "$TASK_FILE" - echo 0 > "$TASK_FILE.counter" - echo "${#arr_ref[@]}" > "$COUNTER_FILE.total" - echo 0 > "$COUNTER_FILE.completed" -} - -get_next_task() { - ( - # 文件锁保证原子操作 - flock -x 9 - counter=$(< $TASK_FILE.counter) - total_tasks=$(wc -l < $TASK_FILE) - - if (( counter >= total_tasks )); then - echo "" - return - fi - - task_name=$(sed -n "$((counter+1))p" $TASK_FILE) - echo $((counter+1)) > "$TASK_FILE.counter" - echo "$task_name" - ) 9> "$TASK_FILE.lock" -} - -# 原子更新完成计数器 -update_progress() { - ( - flock -x 11 - local current=$(< $COUNTER_FILE.completed) - echo $((current + 1)) > $COUNTER_FILE.completed - echo $((current + 1)) - ) 11> $LOCK_FILE -} - -# 获取进度信息 -get_progress() { - ( - flock -s 11 # 共享锁(只读) - completed=$(< $COUNTER_FILE.completed) - total=$(< $COUNTER_FILE.total) - echo "$completed $total" - ) 11> $LOCK_FILE -} - -cleanup_tasks() { - rm -f "$TASK_FILE" "$TASK_FILE.counter" "$TASK_FILE.lock" $LOCK_FILE $COUNTER_FILE* -} - -# 算子列表定义 -OPS=("abs" "add" "addmm" "all" "amax" "argmax" "bitwise_and" "bitwise_not" "bitwise_or" "bmm" \ -"cos" "CrossEntryLoss" "div" "dropout" "eq" "exp" "fill" "ge" "gelu" "group_norm" "gt" "isinf" \ -"isnan" "rsub" "le" "linear" "log_softmax" "lt" "max" "mean" "min" "mm" "mul" "mv" \ -"native_dropout" "ne" "neg" "pow" "prod" "reciprocal" "relu" "rsqrt" "sigmoid" "silu" \ -"sin" "softmax" "sub" "sum" "tanh" "triu") - -total_ops=${#OPS[@]} -echo "======================================" -echo "测试算子列表: ${OPS[@]}" -echo "算子总数: $total_ops" -echo "使用设备数量: $device_count" -echo "每设备线程数: $threads_per_device" -echo "======================================" - -# 初始化性能计数器 - 修复开始时间显示问题 -start_time=$(date +%s) # 使用Unix时间戳 - -# 线程执行函数 - 正确性测试 -run_tests_thread() { - local device_id=$1 - local thread_id=$2 - local device_log_dir=$3 - local thread_log_dir="$device_log_dir/thread_${thread_id}" - mkdir -p "$thread_log_dir" - - while true; do - task_name=$(get_next_task) - [[ -z "$task_name" ]] && break - - echo "[设备 $device_id-线程 $thread_id] 正在执行: pytest -m $task_name --ref cpu -sv" - log_file="${thread_log_dir}/result_${task_name}.log" - - # 执行正确性测试并记录时间 - start_op=$(date +%s) - python -m pytest -m $task_name --dist=loadfile --ref cpu -sv &> "$log_file" - exit_code=$? 
- duration=$(( $(date +%s) - start_op )) - - # 根据退出码记录不同状态 - case $exit_code in - 0) - status="success" - ;; - 1) - status="failure" - ;; - 2) # pytest跳过用例的退出码 - status="skipped" - ;; - *) - status="error" - ;; - esac - - # 记录统计结果 - record_stats $device_id $status - - # 原子更新完成计数 - new_completed=$(update_progress) - - # 获取最新进度状态 - read completed total < <(get_progress) - progress=$(( completed * 100 / total )) - - # 输出结果 - if [ $exit_code -ne 0 ]; then - echo "[错误] [$device_id-$thread_id] $task_name 失败! (用时 ${duration}s, 进度: $completed/$total)" - else - echo "[成功] [$device_id-$thread_id] $task_name 完成! (用时 ${duration}s, 进度: $completed/$total)" - fi - done -} - -# 设备主函数 -run_device() { - local device_id=$1 - local device_log_dir="device_${device_id}_logs" - mkdir -p "$device_log_dir" - - # 创建设备内的线程池 - for ((thread_id=0; thread_id < threads_per_device; thread_id++)); do - run_tests_thread $device_id $thread_id "$device_log_dir" & - done - - # 等待设备内所有线程完成 - wait - echo "======== 设备 $device_id 上所有任务完成 ========" -} - -# 根据参数执行测试 -cd "$DIR_TESTS" || { echo "无法进入目录 $DIR_TESTS"; exit 1; } - -# 创建全局任务队列 -init_task_queue OPS - -# 启动设备主进程 -for ((device_id=0; device_id < device_count; device_id++)); do - ( - export ASCEND_RT_VISIBLE_DEVICES=$device_id - run_device $device_id - ) & -done - -# 等待所有设备完成 -wait -cleanup_tasks - -# ===== 修改:改进的统计信息汇总 ===== -total_success=0 -total_failure=0 -total_skipped=0 -total_error=0 - -# 按设备汇总结果 -for ((device_id=0; device_id < device_count; device_id++)); do - stats_file="${STATS_DIR}/device_${device_id}.stats" - - if [ -f "$stats_file" ]; then - # 从文件加载统计 - d_success=$(grep '^success=' "$stats_file" | cut -d= -f2) - d_failure=$(grep '^failure=' "$stats_file" | cut -d= -f2) - d_skipped=$(grep '^skipped=' "$stats_file" | cut -d= -f2) - d_error=$(grep '^error=' "$stats_file" | cut -d= -f2) - - total_success=$((total_success + d_success)) - total_failure=$((total_failure + d_failure)) - total_skipped=$((total_skipped + d_skipped)) - total_error=$((total_error + d_error)) - - # 记录设备统计 - echo "设备 $device_id 完成情况: $d_success 成功, $d_failure 失败, $d_skipped 跳过, $d_error 错误" - else - echo "警告: 设备 $device_id 的统计文件未找到" - fi -done - -# 清理统计目录 -rm -rf "$STATS_DIR" - -# 计算总耗时 -total_time=$(( $(date +%s) - start_time )) # 使用绝对时间计算总耗时 -hours=$(( total_time / 3600 )) -minutes=$(( (total_time % 3600) / 60 )) -seconds=$(( total_time % 60 )) -time_str=$(printf "%02dh %02dm %02ds" $hours $minutes $seconds) - -# 计算平均耗时 -if [[ $total_ops -gt 0 ]]; then - completed_ops=$((total_success + total_failure + total_error)) - if [[ $completed_ops -gt 0 ]]; then - avg_time=$((total_time / completed_ops)) - avg_min=$((avg_time / 60)) - avg_sec=$((avg_time % 60)) - avg_str=$(printf "%02dm %02ds" $avg_min $avg_sec) - else - avg_str="N/A" - fi -else - avg_str="N/A" -fi - -# 生成统计信息摘要 -{ - echo "===================== flaggems测试统计摘要 =====================" - echo "开始时间: $(date -d @$start_time '+%Y-%m-%d %H:%M:%S')" - echo "结束时间: $(date '+%Y-%m-%d %H:%M:%S')" - echo "测试日期: $(date '+%Y-%m-%d')" - echo "总耗时: $time_str" - echo "--------------------------------------------------------" - echo "总算子数: $total_ops" - echo "成功用例数: $total_success" - echo "失败用例数: $total_failure" - echo "跳过用例数: $total_skipped" - echo "错误用例数: $total_error" - echo "完成用例数: $((total_success + total_failure + total_error))" - - if [[ $total_ops -gt 0 ]]; then - echo "完成率: $(( (total_success + total_failure + total_error) * 100 / total_ops ))%" - else - echo "完成率: N/A" - fi - - if [[ $total_success -gt 0 ]] || [[ $total_failure -gt 0 ]] || [[ 
$total_error -gt 0 ]]; then - success_rate=$(( total_success * 100 / (total_success + total_failure + total_error) )) - echo "成功率: ${success_rate}%" - else - echo "成功率: N/A" - fi - - echo "平均耗时/算子: $avg_str" - echo "--------------------------------------------------------" - echo "设备数量: $device_count" - echo "每设备线程数: $threads_per_device" - echo "========================================================" - echo "" -} | tee -a $SUMMARY_FILE # 追加到统计文件并同时输出到控制台 - -# 归档所有日志文件 -log_dirs=($(find . -maxdepth 1 -type d -name "device_*_logs" 2>/dev/null)) -if [ ${#log_dirs[@]} -gt 0 ]; then - echo "归档日志文件到 $LOG_ARCHIVE" - tar -czf "$LOG_ARCHIVE" "${log_dirs[@]}" - - if mv "$LOG_ARCHIVE" "$PR_LOG_DIR"; then - echo "日志已保存到: $PR_LOG_DIR/$LOG_ARCHIVE" - else - echo "警告:日志移动到 $PR_LOG_DIR 失败" - fi - - # 清理临时日志 - rm -rf "${log_dirs[@]}" -else - echo "警告:未找到任何日志目录,跳过归档" -fi - -echo "所有算子测试执行完成!" -echo "详细统计信息已追加到: $SUMMARY_FILE" -exit 0 diff --git a/third_party/ascend/examples/flaggems_cases/run_flaggems_test.sh b/third_party/ascend/examples/flaggems_cases/run_flaggems_test.sh deleted file mode 100644 index 7a54ccec6..000000000 --- a/third_party/ascend/examples/flaggems_cases/run_flaggems_test.sh +++ /dev/null @@ -1,10 +0,0 @@ - -TEST_flaggems="${WORKSPACE}/triton-ascend/ascend/examples/flaggems_cases" -cd ${TEST_flaggems} -git init -git clone https://gitee.com/leopold0801/flaggems.git -cd flaggems -git checkout 4f3f548 -mv ../op_test_run.sh ./ -ls -al -bash op_test_run.sh 16 32 diff --git a/third_party/ascend/examples/generalization_cases/full_run.sh b/third_party/ascend/examples/generalization_cases/full_run.sh deleted file mode 100755 index ac724c97f..000000000 --- a/third_party/ascend/examples/generalization_cases/full_run.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash -current_date=$(date +%Y%m%d) -pid_log="process_${current_date}.log" -max_parallel=12 - -fifo="/tmp/$$.fifo" -mkfifo $fifo -exec 9<>$fifo -rm -f $fifo - -for ((i=0; i<$max_parallel; i++)); do - echo >&9 -done - -> "$pid_log" - - -if [ -d logs ]; then - rm -rf logs -fi - -mkdir logs - -while IFS= read -r -d $'\0' file; do - read -u 9 - - test_log="./logs/${file%.py}_${current_date}.log" - - { - pytest -sv "$file" -n 16 > "$test_log" 2>&1 - echo >&9 - } & - - echo "[INFO] Activated $(basename "$file"), PID=$!, logging into $test_log." - -done < <(find . -maxdepth 1 -type f -name "test_*.py" ! 
-name "test_common.py" -print0) - -wait -exec 9>&- - -echo "[INFO] All test processes completed, pids logged into ${pid_log}" diff --git a/third_party/ascend/examples/generalization_cases/test_device_print_op.py b/third_party/ascend/examples/generalization_cases/test_device_print_op.py deleted file mode 100644 index fdfa705a5..000000000 --- a/third_party/ascend/examples/generalization_cases/test_device_print_op.py +++ /dev/null @@ -1,147 +0,0 @@ -import torch -import torch_npu -import triton -import triton.language as tl -import pytest -import test_common - -import os - -os.environ["TRITON_DEVICE_PRINT"] = "1" -os.environ["TRITON_ENABLE_TASKQUEUE"] = "0" - -shape = (8, ) -XS = 8 -XVALS_INT = [ - 0, - torch.iinfo(torch.int8).min, - torch.iinfo(torch.int8).max, - torch.iinfo(torch.int16).min, - torch.iinfo(torch.int16).max, - torch.iinfo(torch.int32).min, - torch.iinfo(torch.int32).max, - torch.iinfo(torch.int32).max + 1 -] -XVALS_FP = [ - 0, - torch.finfo(torch.float32).eps, - torch.finfo(torch.float16).eps, - torch.finfo(torch.bfloat16).eps, - torch.finfo(torch.float32).max, - torch.finfo(torch.float16).max, - torch.finfo(torch.bfloat16).max, 1 -] - - -def torch_func(x0, x1): - res = x0 + x1 - return res - - -@triton.jit -def triton_kernel(out_ptr0, in_ptr0, in_ptr1, XBLOCK: tl.constexpr): - idx = tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + idx) - tmp1 = tl.load(in_ptr1 + idx) - tmp2 = tmp0 + tmp1 - tl.device_print("OUTPUT = ", tmp2) - tl.store(out_ptr0 + idx, tmp2) - - -def triton_func(x0, x1, XS): - out = torch.empty_like(x0) - triton_kernel[1, 1, 1](out, x0, x1, XS) - return out - - -@pytest.mark.skip(reason="waiting for bishengir-compile to support") -@pytest.mark.parametrize('sigtype', ['int64']) -@test_common.capture_output("???") -def test_device_print_int64(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - for i in range(x1.numel()): - x1[i] = XVALS_INT[i] - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) - - -@pytest.mark.parametrize('sigtype', ['int32']) -@test_common.capture_output("0,-128,127,-32768,32767,-2147483648,2147483647,-2147483648") -def test_device_print_int32(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - for i in range(x1.numel()): - x1[i] = XVALS_INT[i] - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) - - -@pytest.mark.parametrize('sigtype', ['int16']) -@test_common.capture_output("0,-128,127,-32768,32767,0,-1,0") -def test_device_print_int16(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - for i in range(x1.numel()): - x1[i] = XVALS_INT[i] - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) - - -@pytest.mark.parametrize('sigtype', ['int8']) -@test_common.capture_output("0,-128,127,0,-1,0,-1,0") -def test_device_print_int8(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - for i in range(x1.numel()): - x1[i] = XVALS_INT[i] - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) 
- - -@pytest.mark.parametrize('sigtype', ['float32']) -@test_common.capture_output("0,1.19209e-07,0.000976562,0.0078125,3.40282e+38,65504,3.38953e+38,1") -def test_device_print_fp32(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - for i in range(x1.numel()): - x1[i] = XVALS_FP[i] - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) - - -@pytest.mark.parametrize('sigtype', ['float16']) -@test_common.capture_output("0,1.19209e-07,0.000976562,0.0078125,inf,65504,inf,1") -def test_device_print_fp16(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - for i in range(x1.numel()): - x1[i] = XVALS_FP[i] - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) - - -@pytest.mark.skip(reason="waiting for bishengir-compile to support") -@pytest.mark.parametrize('sigtype', ['bfloat16']) -@test_common.capture_output("???") -def test_device_print_bf16(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - for i in range(x1.numel()): - x1[i] = XVALS_FP[i] - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) diff --git a/third_party/ascend/examples/generalization_cases/test_static_print_and_assert.py b/third_party/ascend/examples/generalization_cases/test_static_print_and_assert.py deleted file mode 100644 index 62105ae3f..000000000 --- a/third_party/ascend/examples/generalization_cases/test_static_print_and_assert.py +++ /dev/null @@ -1,144 +0,0 @@ -import torch -import torch_npu -import triton -import triton.language as tl -import pytest -import test_common -import functools -import os -import re - -shape = (8, ) -XS = 8 - -XVALS_INT = [ - 0, -128, # torch.iinfo(torch.int8).min - 127, # torch.iinfo(torch.int8).max - -32768, # torch.iinfo(torch.int16).min - 32767, # torch.iinfo(torch.int16).max - -2147483648, # torch.iinfo(torch.int32).min - 2147483647, # torch.iinfo(torch.int32).max - 9223372036854775807 -] # torch.iinfo(torch.int64).max - -XVALS_FP = [ - 0.0000000000e+00, # 0 - 1.1921000009e-07, # torch.finfo(torch.float32).eps - 9.7655999707e-04, # torch.finfo(torch.float16).eps - 7.8125000000e-03, # torch.finfo(torch.bfloat16).eps - 3.4027999388e+38, # torch.finfo(torch.float32).max - 6.5504000000e+04, # torch.finfo(torch.float16).max - 3.3894999515e+38, # torch.finfo(torch.bfloat16).max - 1.0000000000e+00 -] # 1 - - -def torch_func(x0, x1): - res = x0 + x1 - return res - - -@triton.jit -def triton_kernel(out_ptr0, in_ptr0, in_ptr1, XBLOCK: tl.constexpr, print_data_ptr: tl.constexpr, - assert_data_ptr: tl.constexpr): - idx = tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + idx) - tmp1 = tl.load(in_ptr1 + idx) - tmp2 = tmp0 + tmp1 - tl.static_print(print_data_ptr) - tl.static_assert(assert_data_ptr == assert_data_ptr, "assert_data should equal assert_data") - tl.store(out_ptr0 + idx, tmp2) - - -def triton_func(x0, x1, XS, print_data_ptr, assert_data_ptr): - out = torch.empty_like(x0) - triton_kernel[1, 1, 1](out, x0, x1, XS, print_data_ptr, assert_data_ptr) - return out - - -@pytest.mark.parametrize('sigtype', ['int8']) -@test_common.capture_output("-128") -def test_static_print_int8(capsys, sigtype): - 
dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS, -128, XVALS_INT[0]) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) - - -@pytest.mark.parametrize('sigtype', ['int16']) -@test_common.capture_output("-32768") -def test_static_print_int16(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS, -32768, XVALS_INT[2]) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) - - -@pytest.mark.parametrize('sigtype', ['int32']) -@test_common.capture_output("-2147483648") -def test_static_print_int32(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS, -2147483648, XVALS_INT[4]) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) - - -@pytest.mark.parametrize('sigtype', ['int64']) -@test_common.capture_output("9223372036854775807") -def test_static_print_int64(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS, 9223372036854775807, XVALS_INT[-1]) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) - - -@pytest.mark.parametrize('sigtype', ['float16']) -@test_common.capture_output("1.1921000009e-07") -def test_static_print_float16(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS, 1.1921000009e-07, XVALS_FP[1]) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) - - -@pytest.mark.parametrize('sigtype', ['float32']) -@test_common.capture_output("0.0078125") -def test_static_print_float32(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS, 7.8125000000e-03, XVALS_FP[0]) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) - - -@pytest.mark.parametrize('sigtype', ['bfloat16']) -@test_common.capture_output("0.00097655999707") -def test_static_print_bfloat16(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS, 9.7655999707e-04, XVALS_FP[2]) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) - - -@pytest.mark.parametrize('sigtype', ['int8']) -@test_common.capture_output("True") -def test_static_print_bool(capsys, sigtype): - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - torch_ref = torch_func(x0, x1) - triton_cal = triton_func(x0, x1, XS, True, True) - test_common.validate_cmp(sigtype, triton_cal, torch_ref) diff --git a/third_party/ascend/examples/inductor_cases/run_inductor_test.sh b/third_party/ascend/examples/inductor_cases/run_inductor_test.sh deleted file mode 100644 index 1f4fcc5cf..000000000 --- a/third_party/ascend/examples/inductor_cases/run_inductor_test.sh +++ /dev/null @@ -1,138 
+0,0 @@ -inductor_skip_list=( - "test_check_accuracy.py" - "test_debug_msg.py" - "test_embedding.py" - "test_force_fallback.py" - "test_foreach_add.py" - "test_geometric.py" - "test_lazy_register.py" -) - -TEST_inductor="${WORKSPACE}/triton-ascend/ascend/examples/inductor_cases" -# 定义统计文件路径 -SUMMARY_FILE="${WORKSPACE}/triton-ascend/ascend/examples/summary.txt" - -# install daily torch_npu -current_date=$(date +%Y%m%d) -wget https://pytorch-package.obs.cn-north-4.myhuaweicloud.com/pta/Daily/v2.6.0/${current_date}.2/pytorch_v2.6.0_py311.tar.gz -tar -zxvf pytorch_v2.6.0_py311.tar.gz -pip install *.dev${current_date}-cp311-cp311-manylinux_2_28_aarch64.whl - -cd ${TEST_inductor} -git init -git remote add origin https://gitcode.com/Ascend/pytorch.git -git config core.sparsecheckout true -echo "test/_inductor" >> .git/info/sparse-checkout -git pull origin v2.6.0:master -TEST_inductor_cases_path="${TEST_inductor}/test/_inductor" -cp ../conftest.py ${TEST_inductor_cases_path} -cd ${TEST_inductor_cases_path} -export PYTHONPATH="${PYTHONPATH}:${TEST_inductor_cases_path}" - -# 记录跳过的测试用例 -echo -e "\n======= Inductor 测试跳过的用例 =======" >> $SUMMARY_FILE -for skip_case in ${inductor_skip_list[@]}; -do - if [ -e "${TEST_inductor_cases_path}/${skip_case}" ];then - echo "跳过测试用例: ${skip_case}" | tee -a $SUMMARY_FILE - mv ${skip_case} "${skip_case}_skip" - fi -done - -# 创建临时日志目录 -LOG_DIR=$(mktemp -d) -INDUCTOR_CASE_LOG_FILE="$LOG_DIR/test_inductor_case_$(date +%Y%m%d).log" - -# 记录测试开始时间 -start_time=$(date +"%Y-%m-%d %H:%M:%S") - -# 执行测试并生成JUnit报告 -python -m pytest -n 16 --dist=loadfile . \ - --junitxml="$LOG_DIR/results.xml" \ - 2>&1 | tee "$INDUCTOR_CASE_LOG_FILE" - -# 解析统计信息 -# 使用Python解析JUnit XML报告 -python3 -c " -import xml.etree.ElementTree as ET -import os - -xml_file = '$LOG_DIR/results.xml' -if not os.path.exists(xml_file): - print('JUnitXML report not found:', xml_file) - exit(1) - -tree = ET.parse(xml_file) -root = tree.getroot() - -total_tests = 0 -passed_tests = 0 -failed_tests = 0 -skipped_tests = 0 -error_tests = 0 - -# 遍历所有testsuite -for testsuite in root.findall('testsuite'): - total_tests += int(testsuite.get('tests', 0)) - skipped_tests += int(testsuite.get('skipped', 0)) - error_tests += int(testsuite.get('errors', 0)) - failed_tests += int(testsuite.get('failures', 0)) - -# 计算通过用例数 -passed_tests = total_tests - skipped_tests - error_tests - failed_tests - -# 输出统计信息 -print(f'total_tests={total_tests}') -print(f'passed_tests={passed_tests}') -print(f'failed_tests={failed_tests}') -print(f'skipped_tests={skipped_tests}') -print(f'error_tests={error_tests}') -" > $LOG_DIR/stats.tmp - -# 加载统计结果 -source $LOG_DIR/stats.tmp -rm $LOG_DIR/stats.tmp - -# 计算测试持续时间 -end_time=$(date +"%Y-%m-%d %H:%M:%S") -duration=$(( $(date -d "$end_time" +%s) - $(date -d "$start_time" +%s) )) -duration_str=$(printf "%02dh %02dm %02ds" $((duration/3600)) $(((duration%3600)/60)) $((duration%60))) - -# 计算通过率 -if [ $total_tests -gt 0 ]; then - pass_rate=$(( 100 * passed_tests / total_tests )) -else - pass_rate=0 -fi - -# 生成统计信息摘要 -stats_summary=" -inductor 测试用例结果摘要: ------------------------- -开始时间: $start_time -结束时间: $end_time -总耗时: $duration_str ------------------------- -总用例数: $total_tests -成功用例: $passed_tests -失败用例: $failed_tests -跳过用例: $skipped_tests -错误用例: $error_tests ------------------------- -通过率: ${pass_rate}% (成功/总数) ------------------------- -" - -# 输出统计信息到控制台 -echo "$stats_summary" - -# 追加统计信息到summary.txt -echo "$stats_summary" >> $SUMMARY_FILE - -# 保存原始日志文件 -cp "$INDUCTOR_CASE_LOG_FILE" "/home/pr_test_log" - -# 
清理临时文件 -rm -rf "$LOG_DIR" - -echo "测试统计信息已追加到: $SUMMARY_FILE" diff --git a/third_party/ascend/examples/model_cases/deberta.py b/third_party/ascend/examples/model_cases/deberta.py deleted file mode 100644 index b1ace76f1..000000000 --- a/third_party/ascend/examples/model_cases/deberta.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import logging -import os - -import torch -import torch_npu -import torch_npu._inductor - -from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification - -os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" - -logging.basicConfig(level=logging.DEBUG) - -torch.npu.config.allow_internal_format = False -torch.manual_seed(0) -torch.npu.manual_seed(0) -tokenizer = AutoTokenizer.from_pretrained("./microsoft/deberta-v3-large") - -sample_texts = ["This is a positive example.", "This might be negative."] * 128 - -model_ = AutoModelForTokenClassification.from_pretrained("./microsoft/deberta-v3-large", device_map="npu:0") -model_.eval() - -inputs = tokenizer(sample_texts, max_length=512, padding="longest", truncation=True, return_tensors="pt", - add_special_tokens=True).to("npu:0") - - -def model(**model_inputs): - with torch.no_grad(): - return model_(**model_inputs).logits - - -y = model(**inputs) -logging.info("result eager: " + str(torch.flatten(y)[:100])) - -model_compiled = torch.compile(model_) - -z = model_compiled(**inputs) -logging.info("result compiled: " + str(torch.flatten(z)[:100])) - -torch.testing.assert_close(y, z, atol=1e-4, rtol=1e-4) -logging.info("deberta accuracy check pass!") diff --git a/third_party/ascend/examples/pytest_ut/test_count_dim0.py b/third_party/ascend/examples/pytest_ut/test_count_dim0.py deleted file mode 100644 index 8b7039225..000000000 --- a/third_party/ascend/examples/pytest_ut/test_count_dim0.py +++ /dev/null @@ -1,191 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
-import pytest -import triton -import triton.language as tl -import time -import torch -import torch_npu -import test_common - - -def standard_count(x0, cmp_val, dim, dtype): - res = (x0 == cmp_val).sum(dim=dim) - return res - - -def standard_count_gt(x0, cmp_val, dim, dtype): - res = (x0 > cmp_val).sum(dim=dim) - return res - - -def standard_count_lt(x0, cmp_val, dim, dtype): - res = (x0 < cmp_val).sum(dim=dim) - return res - - -@triton.jit -def count(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x == cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + nblk_idx, ret, mask=nmask) - - -@triton.jit -def count_gt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x > cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + nblk_idx, ret, mask=nmask) - - -@triton.jit -def count_lt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x < cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + nblk_idx, ret, mask=nmask) - - -shapes = [(57, 3, 64, 16), (57, -32, 64, 32), (57, 37, 64, 64), (64, 3, 64, 16), (64, -32, 64, 32), (64, 37, 64, 64), - (3, 3, 8, 8), (-32, 3, 32, 8), (37, 3, 64, 8), (3, 1, 8, 8), (-32, 1, 32, 8), (37, 1, 64, 8)] - -map_for_64_t = {37: (31, 32), 263: (107, 128)} -map_for_32_t = {263: (137, 256)} - -types0 = [ - (torch.int8, 'int8'), -] - - -@pytest.mark.parametrize('dtype, sigtype', types0) -@pytest.mark.parametrize('M, N, MNUMEL, NNUMEL', shapes) -def test_count_eq_dim0_common(dtype, sigtype, M, N, MNUMEL, NNUMEL): - M = (-M) // torch.tensor(0, dtype=dtype).element_size() if M < 0 else M - N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N - - if sigtype == 'int64': - M = map_for_64_t[M][0] if M in map_for_64_t else M - MNUMEL = map_for_64_t[M][1] if M in map_for_64_t else MNUMEL - N = map_for_64_t[N][0] if N in map_for_64_t else N - NNUMEL = map_for_64_t[N][1] if N in map_for_64_t else NNUMEL - - elif sigtype == 'float32' or sigtype == 'bfloat16' or sigtype == 'int32': - M = map_for_32_t[M][0] if M in map_for_32_t else M - MNUMEL = map_for_32_t[M][1] if M in map_for_32_t else MNUMEL - N = map_for_32_t[N][0] if N in map_for_32_t else N - NNUMEL = map_for_32_t[N][1] if N in map_for_32_t else NNUMEL - - print(f"sum : ({M}, {N}) {dtype} {sigtype}") - cmp_val = 8 - x0 = test_common.generate_tensor(shape=(M, N), dtype=sigtype) - ans = standard_count(x0, cmp_val, 0, dtype) - x0 = x0.npu() - print(ans) - output = torch.zeros((N, ), 
dtype=torch.float32).npu() - count[1, 1, 1](x0, output, cmp_val, 0, M=M, N=N, MNUMEL=MNUMEL, NNUMEL=NNUMEL, debug=True) - print(output) - test_common.validate_cmp('float32', output, ans.to(torch.float32)) - - -#------------------------------------------------------------------------------------- - -types1 = [ - (torch.float32, 'float32'), - (torch.float32, 'float16'), - (torch.int8, 'int8'), -] - - -@pytest.mark.parametrize('dtype, sigtype', types1) -@pytest.mark.parametrize('M, N, MNUMEL, NNUMEL', shapes) -def test_count_gt_dim0_common(dtype, sigtype, M, N, MNUMEL, NNUMEL): - M = (-M) // torch.tensor(0, dtype=dtype).element_size() if M < 0 else M - N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N - - if sigtype == 'int64': - M = map_for_64_t[M][0] if M in map_for_64_t else M - MNUMEL = map_for_64_t[M][1] if M in map_for_64_t else MNUMEL - N = map_for_64_t[N][0] if N in map_for_64_t else N - NNUMEL = map_for_64_t[N][1] if N in map_for_64_t else NNUMEL - - elif sigtype == 'float32' or sigtype == 'bfloat16' or sigtype == 'int32': - M = map_for_32_t[M][0] if M in map_for_32_t else M - MNUMEL = map_for_32_t[M][1] if M in map_for_32_t else MNUMEL - N = map_for_32_t[N][0] if N in map_for_32_t else N - NNUMEL = map_for_32_t[N][1] if N in map_for_32_t else NNUMEL - - print(f"sum : ({M}, {N}) {dtype} {sigtype}") - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - x0 = test_common.generate_tensor(shape=(M, N), dtype=sigtype) - ans = standard_count_gt(x0, cmp_val, 0, dtype) - x0 = x0.npu() - print(ans) - output = torch.zeros((N, ), dtype=torch.float32).npu() - count_gt[1, 1, 1](x0, output, cmp_val, 0, M=M, N=N, MNUMEL=MNUMEL, NNUMEL=NNUMEL, debug=True) - print(output) - test_common.validate_cmp("float32", output, ans.to(torch.float32)) - - -shapes1 = [(64, 3, 64, 16), (64, -32, 64, 32), (64, 37, 64, 64)] - - -@pytest.mark.parametrize('dtype, sigtype', types1) -@pytest.mark.parametrize('M, N, MNUMEL, NNUMEL', shapes1) -def test_count_lt_dim0_common(dtype, sigtype, M, N, MNUMEL, NNUMEL): - M = (-M) // torch.tensor(0, dtype=dtype).element_size() if M < 0 else M - N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N - - if sigtype == 'int64': - M = map_for_64_t[M][0] if M in map_for_64_t else M - MNUMEL = map_for_64_t[M][1] if M in map_for_64_t else MNUMEL - N = map_for_64_t[N][0] if N in map_for_64_t else N - NNUMEL = map_for_64_t[N][1] if N in map_for_64_t else NNUMEL - - elif sigtype == 'float32' or sigtype == 'bfloat16' or sigtype == 'int32': - M = map_for_32_t[M][0] if M in map_for_32_t else M - MNUMEL = map_for_32_t[M][1] if M in map_for_32_t else MNUMEL - N = map_for_32_t[N][0] if N in map_for_32_t else N - NNUMEL = map_for_32_t[N][1] if N in map_for_32_t else NNUMEL - - print(f"sum : ({M}, {N}) {dtype} {sigtype}") - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - x0 = test_common.generate_tensor(shape=(M, N), dtype=sigtype) - ans = standard_count_lt(x0, cmp_val, 0, dtype) - x0 = x0.npu() - print(ans) - output = torch.zeros((N, ), dtype=torch.float32).npu() - count_lt[1, 1, 1](x0, output, cmp_val, 0, M=M, N=N, MNUMEL=MNUMEL, NNUMEL=NNUMEL, debug=True) - print(output) - test_common.validate_cmp("float32", output, ans.to(torch.float32)) diff --git a/third_party/ascend/examples/pytest_ut/test_count_dim1.py b/third_party/ascend/examples/pytest_ut/test_count_dim1.py deleted file mode 100644 index 06afb93ce..000000000 --- a/third_party/ascend/examples/pytest_ut/test_count_dim1.py +++ /dev/null @@ -1,195 +0,0 @@ -# -*- coding: utf-8 
-*- -# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. -import pytest -import triton -import triton.language as tl -import time - -import torch -import torch_npu -import test_common - - -def standard_count(x0, cmp_val, dim, dtype): - res = (x0 == cmp_val).sum(dim=dim) - return res - - -def standard_count_gt(x0, cmp_val, dim, dtype): - res = (x0 > cmp_val).sum(dim=dim) - return res - - -def standard_count_lt(x0, cmp_val, dim, dtype): - res = (x0 < cmp_val).sum(dim=dim) - return res - - -@triton.jit -def count(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x == cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + mblk_idx, ret, mask=mmask) - - -@triton.jit -def count_gt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x > cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + mblk_idx, ret, mask=mmask) - - -@triton.jit -def count_lt(in_ptr0, out_ptr0, cmp_val, dim: tl.constexpr, M: tl.constexpr, N: tl.constexpr, MNUMEL: tl.constexpr, - NNUMEL: tl.constexpr): - mblk_idx = tl.arange(0, MNUMEL) - nblk_idx = tl.arange(0, NNUMEL) - mmask = mblk_idx < M - nmask = nblk_idx < N - mask = (mmask[:, None]) & (nmask[None, :]) - idx = mblk_idx[:, None] * N + nblk_idx[None, :] - x = tl.load(in_ptr0 + idx, mask=mask, other=0) - tmp1 = (x < cmp_val) - tmp2 = tmp1.to(tl.float32) - ret = tl.sum(tmp2, dim) - tl.store(out_ptr0 + mblk_idx, ret, mask=mmask) - - -# if shape axis = 32/256 , then actual shape = axis/element_size() - -shapes = [(57, 3, 64, 16), (57, -32, 64, 32), (64, 3, 64, 16), (64, -32, 64, 32), (3, 3, 8, 8), (-32, 3, 32, 8), - (37, 3, 64, 8), (3, 1, 8, 8), (-32, 1, 32, 8), (37, 1, 64, 8)] - -map_for_64_t = {37: (31, 32), 263: (107, 128)} -map_for_32_t = {263: (137, 256)} - -types0 = [ - (torch.int8, 'int8'), -] - - -@pytest.mark.parametrize('dtype, sigtype', types0) -@pytest.mark.parametrize('M, N, MNUMEL, NNUMEL', shapes) -def test_count_eq_dim0_common(dtype, sigtype, M, N, MNUMEL, NNUMEL): - M = (-M) // torch.tensor(0, dtype=dtype).element_size() if M < 0 else M - N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N - - if sigtype == 'int64': - M = map_for_64_t[M][0] if M in map_for_64_t else M - MNUMEL = map_for_64_t[M][1] if M in map_for_64_t else MNUMEL - N = map_for_64_t[N][0] if N in map_for_64_t else N - NNUMEL = map_for_64_t[N][1] if N in map_for_64_t else NNUMEL - - elif sigtype == 'float32' or sigtype == 'bfloat16' or sigtype == 'int32': - M = map_for_32_t[M][0] if M in map_for_32_t else M - MNUMEL = map_for_32_t[M][1] if M in map_for_32_t else MNUMEL - N = map_for_32_t[N][0] if N in map_for_32_t else N - NNUMEL = map_for_32_t[N][1] if N in map_for_32_t else NNUMEL - - print(f"sum : ({M}, {N}) {dtype} {sigtype}") - cmp_val = 8 - x0 = test_common.generate_tensor(shape=(M, N), dtype=sigtype) - 
ans = standard_count(x0, cmp_val, 1, dtype) - x0 = x0.npu() - print(ans) - output = torch.zeros((M, ), dtype=torch.float32).npu() - count[1, 1, 1](x0, output, cmp_val, 1, M=M, N=N, MNUMEL=MNUMEL, NNUMEL=NNUMEL, debug=True) - print(output) - test_common.validate_cmp('float32', output, ans.to(torch.float32)) - - -#------------------------------------------------------------------------------------- - -types1 = [ - (torch.float32, 'float32'), - (torch.float32, 'float16'), - (torch.int8, 'int8'), -] - - -@pytest.mark.parametrize('dtype, sigtype', types1) -@pytest.mark.parametrize('M, N, MNUMEL, NNUMEL', shapes) -def test_count_gt_dim0_common(dtype, sigtype, M, N, MNUMEL, NNUMEL): - M = (-M) // torch.tensor(0, dtype=dtype).element_size() if M < 0 else M - N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N - - if sigtype == 'int64': - M = map_for_64_t[M][0] if M in map_for_64_t else M - MNUMEL = map_for_64_t[M][1] if M in map_for_64_t else MNUMEL - N = map_for_64_t[N][0] if N in map_for_64_t else N - NNUMEL = map_for_64_t[N][1] if N in map_for_64_t else NNUMEL - - elif sigtype == 'float32' or sigtype == 'bfloat16' or sigtype == 'int32': - M = map_for_32_t[M][0] if M in map_for_32_t else M - MNUMEL = map_for_32_t[M][1] if M in map_for_32_t else MNUMEL - N = map_for_32_t[N][0] if N in map_for_32_t else N - NNUMEL = map_for_32_t[N][1] if N in map_for_32_t else NNUMEL - - print(f"sum : ({M}, {N}) {dtype} {sigtype}") - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - x0 = test_common.generate_tensor(shape=(M, N), dtype=sigtype) - ans = standard_count_gt(x0, cmp_val, 1, dtype) - x0 = x0.npu() - print(ans) - output = torch.zeros((M, ), dtype=torch.float32).npu() - count_gt[1, 1, 1](x0, output, cmp_val, 1, M=M, N=N, MNUMEL=MNUMEL, NNUMEL=NNUMEL, debug=True) - print(output) - test_common.validate_cmp('float32', output, ans.to(torch.float32)) - - -types2 = [(torch.int8, 'int8')] -shapes2 = [(57, -32, 64, 32), (64, -32, 64, 32)] - - -@pytest.mark.parametrize('dtype, sigtype', types2) -@pytest.mark.parametrize('M, N, MNUMEL, NNUMEL', shapes2) -def test_count_lt_dim0_common(dtype, sigtype, M, N, MNUMEL, NNUMEL): - M = (-M) // torch.tensor(0, dtype=dtype).element_size() if M < 0 else M - N = (-N) // torch.tensor(0, dtype=dtype).element_size() if N < 0 else N - - if sigtype == 'int64': - M = map_for_64_t[M][0] if M in map_for_64_t else M - MNUMEL = map_for_64_t[M][1] if M in map_for_64_t else MNUMEL - N = map_for_64_t[N][0] if N in map_for_64_t else N - NNUMEL = map_for_64_t[N][1] if N in map_for_64_t else NNUMEL - - elif sigtype == 'float32' or sigtype == 'bfloat16' or sigtype == 'int32': - M = map_for_32_t[M][0] if M in map_for_32_t else M - MNUMEL = map_for_32_t[M][1] if M in map_for_32_t else MNUMEL - N = map_for_32_t[N][0] if N in map_for_32_t else N - NNUMEL = map_for_32_t[N][1] if N in map_for_32_t else NNUMEL - - print(f"sum : ({M}, {N}) {dtype} {sigtype}") - if dtype == torch.int8: - cmp_val = 8 - else: - cmp_val = 0.5 - x0 = test_common.generate_tensor(shape=(M, N), dtype=sigtype) - ans = standard_count_lt(x0, cmp_val, 1, dtype) - x0 = x0.npu() - print(ans) - output = torch.zeros((M, ), dtype=torch.float32).npu() - count_lt[1, 1, 1](x0, output, cmp_val, 1, M=M, N=N, MNUMEL=MNUMEL, NNUMEL=NNUMEL, debug=True) - print(output) - test_common.validate_cmp('float32', output, ans.to(torch.float32)) diff --git a/third_party/ascend/examples/pytest_ut_regbase/README.md b/third_party/ascend/examples/pytest_ut_regbase/README.md deleted file mode 100644 index 
282ed3fd6..000000000 --- a/third_party/ascend/examples/pytest_ut_regbase/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# Test Guide - -When you already set up bisheng-triton in this branch `regbase`, you can run the script `run_test.sh` to run the unit tests in this directory. - -Notes: - -- It seems that 310B4 can run max 7 threads at the same time. So the `-n` option of pytest should be no more than 7. Otherwise, the output of your kernel maybe empty. diff --git a/third_party/ascend/examples/pytest_ut_regbase/run_test.sh b/third_party/ascend/examples/pytest_ut_regbase/run_test.sh deleted file mode 100755 index 8acb49538..000000000 --- a/third_party/ascend/examples/pytest_ut_regbase/run_test.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -TRITON_ASCEND_ARCH=Ascend310B4 \ -TRITON_ENABLE_TASKQUEUE=0 \ -pytest -n 7 --dist=load \ -. diff --git a/third_party/ascend/examples/pytest_ut_regbase/test_binary.py b/third_party/ascend/examples/pytest_ut_regbase/test_binary.py deleted file mode 100644 index 53cc87162..000000000 --- a/third_party/ascend/examples/pytest_ut_regbase/test_binary.py +++ /dev/null @@ -1,168 +0,0 @@ -import triton -import triton.language as tl -import torch -import pytest -import atexit -from test_common import ( - generate_tensor, - validate_cmp, - fill_zero_with_one, - _float_dtypes_without_bf16 as _float_dtypes, - _int_dtypes, - _shape_1d, -) - -################################################# -FUNCTIONS_TO_TEST = { - 'add': ("x0 + x1", "x0 + x1"), - 'sub': ("x0 - x1", "x0 - x1"), - 'mul': ("x0 * x1", "x0 * x1"), - 'div': - ("x0 / x1", - "x0 / x1 if x0.dtype in [torch.float32, torch.float16] else (x0.to(torch.float32) / x1.to(torch.float32)).to(x0.dtype)" - ), - 'floordiv': ("x0 // x1", "x0 // x1"), - 'mod': ("x0 % x1", "(x0.cpu() % x1.cpu()).npu() if x0.dtype in [torch.int8, torch.int16] else x0 % x1"), - 'and': ("x0 & x1", "x0 & x1"), - 'or': ("x0 | x1", "x0 | x1"), - 'xor': ("x0 ^ x1", "x0 ^ x1"), - 'gt': ("x0 > x1", "x0 > x1"), - 'ge': ("x0 >= x1", "x0 >= x1"), - 'lt': ("x0 < x1", "x0 < x1"), - 'le': ("x0 <= x1", "x0 <= x1"), - 'eq': ("x0 == x1", "x0 == x1"), - 'ne': ("x0 != x1", "x0 != x1"), - 'cdiv': ("tl.cdiv(x0, x1)", "( (x0.cpu() + x1.cpu() - 1) // x1.cpu() ).npu()"), - 'fdiv': ("tl.fdiv(x0, x1)", "x0 / x1"), - 'div_rn': ("tl.div_rn(x0, x1)", "x0 / x1"), - 'logical_and': ("tl.logical_and(x0, x1)", "torch.logical_and(x0, x1)"), - 'logical_or': ("tl.logical_or(x0, x1)", "torch.logical_or(x0, x1)"), - 'maximum': ("tl.maximum(x0, x1)", "torch.maximum(x0, x1)"), - 'minimum': ("tl.minimum(x0, x1)", "torch.minimum(x0, x1)"), -} -################################################# - -# Global dictionary to keep track of temporary files -_temp_kernel_files = {} - - -def _cleanup_temp_files(): - import os - for file_path in _temp_kernel_files.values(): - try: - if os.path.exists(file_path): - os.unlink(file_path) - except: - pass - - -atexit.register(_cleanup_temp_files) - - -def create_triton_kernel(func_name, func_pattern): - import tempfile - import os - kernel_source = f""" -import triton -import triton.language as tl - -@triton.jit -def triton_kernel(in_ptr0, in_ptr1, out_ptr0, numel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base = tl.arange(0, XBLOCK_SUB) - num_loop: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB - for loop in range(num_loop): - idx = offset + loop * XBLOCK_SUB + base - msk = idx < numel - x0 = tl.load(in_ptr0 + idx, mask=msk) - x1 = tl.load(in_ptr1 + idx, mask=msk) - y0 = {func_pattern} - 
tl.store(out_ptr0 + idx, y0, mask=msk) -""" - - # Create a temporary file with a unique name based on the function name - if func_name in _temp_kernel_files: - temp_file_path = _temp_kernel_files[func_name] - else: - fd, temp_file_path = tempfile.mkstemp(suffix='.py', prefix=f'triton_kernel_{func_name}_') - os.close(fd) # We don't need the file descriptor - _temp_kernel_files[func_name] = temp_file_path - - # Write the kernel source to the file - with open(temp_file_path, 'w') as f: - f.write(kernel_source) - - # Import the kernel from the temporary file - import importlib.util - module_name = f"triton_kernel_{func_name.replace('.', '_')}" - spec = importlib.util.spec_from_file_location(module_name, temp_file_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module.triton_kernel - - -@pytest.mark.parametrize('dtype', _float_dtypes + _int_dtypes) -@pytest.mark.parametrize('xblock_sub', _shape_1d) -@pytest.mark.parametrize('func_name', FUNCTIONS_TO_TEST.keys()) -def test_binary(dtype, xblock_sub, func_name): - - if dtype in _int_dtypes: - skip_int_dtype_ops = [ - 'div_rn', - 'fdiv', - ] - if func_name in skip_int_dtype_ops: - pytest.skip(f"{func_name} only tested with float dtypes") - if dtype in _float_dtypes: - skip_float_dtype_ops = [ - 'mod', - 'floordiv', - 'cdiv', - 'and', - 'or', - 'xor', - ] - if func_name in skip_float_dtype_ops: - pytest.skip(f"{func_name} only tested with int dtypes") - if func_name == 'mod' and dtype == 'int64': - pytest.skip(f"{func_name} skips int64") - if func_name in ['logical_and', 'logical_or'] and dtype != 'bool': - pytest.skip(f"{func_name} tests only bool") - - xblock = triton.next_power_of_2(xblock_sub) - shape = (xblock, ) - triton_func_op, torch_func_op = FUNCTIONS_TO_TEST[func_name] - - def torch_func(x0, x1): - return eval(torch_func_op) - - def get_autotune_config(): - return [ - triton.Config({'XBLOCK': xblock, 'XBLOCK_SUB': xblock_sub}), - ] - - triton_kernel = create_triton_kernel(func_name, triton_func_op) - triton_kernel = triton.autotune( - configs=get_autotune_config(), - key=['numel'], - )(triton_kernel) - - def triton_func(x0, x1, out_dtype): - y0 = torch.empty_like(x0).to(out_dtype) - numel = x0.numel() - grid = lambda meta: (triton.cdiv(numel, meta['XBLOCK']), ) - triton_kernel[grid](x0, x1, y0, numel) - return y0 - - x0 = generate_tensor(shape, dtype).npu() - x1 = generate_tensor(shape, dtype).npu() - - if func_name in ['div', 'fdiv', 'div_rn', 'cdiv', 'floordiv', 'mod']: - x1 = fill_zero_with_one(x1) - out_dtype = x0.dtype - if func_name in ['gt', 'gt', 'lt', 'le', 'eq', 'ne']: - out_dtype = torch.bool - - triton_cal = triton_func(x0, x1, out_dtype) - torch_ref = torch_func(x0, x1) - validate_cmp(dtype, triton_cal, torch_ref) diff --git a/third_party/ascend/examples/pytest_ut_regbase/test_binary_2d.py b/third_party/ascend/examples/pytest_ut_regbase/test_binary_2d.py deleted file mode 100644 index 983b1f186..000000000 --- a/third_party/ascend/examples/pytest_ut_regbase/test_binary_2d.py +++ /dev/null @@ -1,178 +0,0 @@ -import triton -import triton.language as tl -import torch -import pytest -import atexit -from test_common import ( - generate_tensor, - validate_cmp, - fill_zero_with_one, - _float_dtypes_without_bf16 as _float_dtypes, - _int_dtypes, - _shape_1d, -) - -################################################# -FUNCTIONS_TO_TEST = { - 'add': ("x0 + x1", "x0 + x1"), - 'sub': ("x0 - x1", "x0 - x1"), - 'mul': ("x0 * x1", "x0 * x1"), - 'div': - ("x0 / x1", - "x0 / x1 if x0.dtype in 
[torch.float32, torch.float16] else (x0.to(torch.float32) / x1.to(torch.float32)).to(x0.dtype)" - ), - 'floordiv': ("x0 // x1", "x0 // x1"), - 'cdiv': ("tl.cdiv(x0, x1)", "( (x0.cpu() + x1.cpu() - 1) // x1.cpu() ).npu()"), - 'fdiv': ("tl.fdiv(x0, x1)", "x0 / x1"), - 'div_rn': ("tl.div_rn(x0, x1)", "x0 / x1"), - 'mod': ("x0 % x1", "(x0.cpu() % x1.cpu()).npu() if x0.dtype in [torch.int8, torch.int16] else x0 % x1"), - # 'and': ("x0 & x1", "x0 & x1"), - # 'or': ("x0 | x1", "x0 | x1"), - # 'xor': ("x0 ^ x1", "x0 ^ x1"), - # 'gt': ("x0 > x1", "x0 > x1"), - # 'ge': ("x0 >= x1", "x0 >= x1"), - # 'lt': ("x0 < x1", "x0 < x1"), - # 'le': ("x0 <= x1", "x0 <= x1"), - # 'eq': ("x0 == x1", "x0 == x1"), - # 'ne': ("x0 != x1", "x0 != x1"), - # 'logical_and': ("tl.logical_and(x0, x1)", "torch.logical_and(x0, x1)"), - # 'logical_or': ("tl.logical_or(x0, x1)", "torch.logical_or(x0, x1)"), - # 'maximum': ("tl.maximum(x0, x1)", "torch.maximum(x0, x1)"), - # 'minimum': ("tl.minimum(x0, x1)", "torch.minimum(x0, x1)"), -} -################################################# - -# Global dictionary to keep track of temporary files -_temp_kernel_files = {} - - -def _cleanup_temp_files(): - import os - for file_path in _temp_kernel_files.values(): - try: - if os.path.exists(file_path): - os.unlink(file_path) - except: - pass - - -atexit.register(_cleanup_temp_files) - - -def create_triton_kernel(func_name, func_pattern): - import tempfile - import os - kernel_source = f""" -import triton -import triton.language as tl - -@triton.jit -def triton_kernel(in_ptr0, in_ptr1, out_ptr0, xnumel, ynumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr, YBLOCK: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - xbase = tl.arange(0, XBLOCK_SUB) - ybase = tl.arange(0, YBLOCK) - num_loop: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB - for loop in range(num_loop): - xidx = offset + loop * XBLOCK_SUB + xbase - xmsk = xidx < xnumel - yidx = ybase - ymsk = yidx < ynumel - idx = xidx[:, None] * ynumel + yidx[None, :] - msk = xmsk[:, None] & ymsk[None, :] - x0 = tl.load(in_ptr0 + idx, mask=msk) - x1 = tl.load(in_ptr1 + idx, mask=msk) - y0 = {func_pattern} - tl.store(out_ptr0 + idx, y0, mask=msk) -""" - - # Create a temporary file with a unique name based on the function name - if func_name in _temp_kernel_files: - temp_file_path = _temp_kernel_files[func_name] - else: - fd, temp_file_path = tempfile.mkstemp(suffix='.py', prefix=f'triton_kernel_{func_name}_') - os.close(fd) # We don't need the file descriptor - _temp_kernel_files[func_name] = temp_file_path - - # Write the kernel source to the file - with open(temp_file_path, 'w') as f: - f.write(kernel_source) - - # Import the kernel from the temporary file - import importlib.util - module_name = f"triton_kernel_{func_name.replace('.', '_')}" - spec = importlib.util.spec_from_file_location(module_name, temp_file_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module.triton_kernel - - -# @pytest.mark.parametrize('dtype', _float_dtypes + _int_dtypes) -@pytest.mark.parametrize('dtype', ['float32']) -@pytest.mark.parametrize('xblock_sub', [3]) -@pytest.mark.parametrize('yblock', _shape_1d) -@pytest.mark.parametrize('func_name', FUNCTIONS_TO_TEST.keys()) -def test_binary(dtype, xblock_sub, yblock, func_name): - - if dtype in _int_dtypes: - skip_int_dtype_ops = [ - 'div_rn', - 'fdiv', - ] - if func_name in skip_int_dtype_ops: - pytest.skip(f"{func_name} only tested with float dtypes") - if dtype in _float_dtypes: - 
skip_float_dtype_ops = [ - 'mod', - 'floordiv', - 'cdiv', - 'and', - 'or', - 'xor', - ] - if func_name in skip_float_dtype_ops: - pytest.skip(f"{func_name} only tested with int dtypes") - if func_name == 'mod' and dtype == 'int64': - pytest.skip(f"{func_name} skips int64") - if func_name in ['logical_and', 'logical_or'] and dtype != 'bool': - pytest.skip(f"{func_name} tests only bool") - - xblock = triton.next_power_of_2(xblock_sub) - shape = (xblock, yblock) - triton_func_op, torch_func_op = FUNCTIONS_TO_TEST[func_name] - - def torch_func(x0, x1): - return eval(torch_func_op) - - def get_autotune_config(): - return [ - triton.Config({'XBLOCK': xblock, 'XBLOCK_SUB': xblock_sub, 'YBLOCK': yblock}), - ] - - triton_kernel = create_triton_kernel(func_name, triton_func_op) - triton_kernel = triton.autotune( - configs=get_autotune_config(), - key=['xnumel'], - )(triton_kernel) - - def triton_func(x0, x1, out_dtype): - xnumel, ynumel = x0.shape - y0 = torch.empty_like(x0).to(out_dtype) - grid = lambda meta: (triton.cdiv(xnumel, meta['XBLOCK']), ) - triton_kernel[grid](x0, x1, y0, xnumel, ynumel) - return y0 - - x0 = generate_tensor(shape, dtype).npu() - x1 = generate_tensor(shape, dtype).npu() - - if func_name in ['div', 'fdiv', 'div_rn', 'cdiv', 'floordiv', 'mod']: - x1 = fill_zero_with_one(x1) - out_dtype = x0.dtype - if func_name in ['gt', 'gt', 'lt', 'le', 'eq', 'ne']: - out_dtype = torch.bool - - triton_cal = triton_func(x0, x1, out_dtype) - torch_ref = torch_func(x0, x1) - validate_cmp(dtype, triton_cal, torch_ref) - - -test_binary('float32', 3, 741, 'add') diff --git a/third_party/ascend/examples/pytest_ut_regbase/test_cast.py b/third_party/ascend/examples/pytest_ut_regbase/test_cast.py deleted file mode 100644 index cc47f5833..000000000 --- a/third_party/ascend/examples/pytest_ut_regbase/test_cast.py +++ /dev/null @@ -1,137 +0,0 @@ -import triton -import triton.language as tl -import torch -import pytest -import atexit -from test_common import ( - generate_tensor, - validate_cmp, - fill_negative_with_one, - _float_dtypes_without_bf16 as _float_dtypes, - _int_dtypes, - _shape_1d, -) - -################################################# -FUNCTIONS_TO_TEST = { - 'cast': ("x0.to(dst_dtype) + x1", "x0.to(dst_dtype) + x1"), -} -################################################# - -# Global dictionary to keep track of temporary files -_temp_kernel_files = {} - - -def _cleanup_temp_files(): - import os - for file_path in _temp_kernel_files.values(): - try: - if os.path.exists(file_path): - os.unlink(file_path) - except: - pass - - -atexit.register(_cleanup_temp_files) - - -def create_triton_kernel(func_name, func_pattern): - import tempfile - import os - kernel_source = f""" -import triton -import triton.language as tl - -@triton.jit -def triton_kernel(in_ptr0, in_ptr1, out_ptr0, numel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr, dst_dtype: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base = tl.arange(0, XBLOCK_SUB) - num_loop: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB - for loop in range(num_loop): - idx = offset + loop * XBLOCK_SUB + base - msk = idx < numel - x0 = tl.load(in_ptr0 + idx, mask=msk) - x1 = tl.load(in_ptr1 + idx, mask=msk) - y0 = {func_pattern} - tl.store(out_ptr0 + idx, y0, mask=msk) -""" - - # Create a temporary file with a unique name based on the function name - if func_name in _temp_kernel_files: - temp_file_path = _temp_kernel_files[func_name] - else: - fd, temp_file_path = tempfile.mkstemp(suffix='.py', prefix=f'triton_kernel_{func_name}_') - 
os.close(fd) # We don't need the file descriptor - _temp_kernel_files[func_name] = temp_file_path - - # Write the kernel source to the file - with open(temp_file_path, 'w') as f: - f.write(kernel_source) - - # Import the kernel from the temporary file - import importlib.util - module_name = f"triton_kernel_{func_name.replace('.', '_')}" - spec = importlib.util.spec_from_file_location(module_name, temp_file_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module.triton_kernel - - -@pytest.mark.parametrize('src_dtype', ['float32']) -@pytest.mark.parametrize('dst_dtype', _float_dtypes + _int_dtypes) -@pytest.mark.parametrize('xblock_sub', _shape_1d) -@pytest.mark.parametrize('func_name', FUNCTIONS_TO_TEST.keys()) -def test_cast(src_dtype, dst_dtype, xblock_sub, func_name): - - if (src_dtype == dst_dtype): - return - - # if dtype in _int_dtypes: - # skip_int_dtype_ops = [ - # 'exp', 'exp2', 'log', 'log2', 'sin', 'cos', 'sqrt', 'rsqrt', 'sqrt_rn', - # 'sigmoid', 'erf', 'ceil', 'floor', - # ] - # if func_name in skip_int_dtype_ops: - # pytest.skip(f"{func_name} only tested with float dtypes") - # if dtype in _float_dtypes: - # skip_float_dtype_ops = [ - # 'not', 'invert', 'lshift', 'rshift', - # ] - # if func_name in skip_float_dtype_ops: - # pytest.skip(f"{func_name} only tested with int dtypes") - - xblock = triton.next_power_of_2(xblock_sub) - shape = (xblock, ) - triton_func_op, torch_func_op = FUNCTIONS_TO_TEST[func_name] - - def torch_func(x0, dst_dtype): - x1 = torch.ones((x0.numel(), ), dtype=dst_dtype, device=x0.device) - return eval(torch_func_op) - - def get_autotune_config(): - return [ - triton.Config({'XBLOCK': xblock, 'XBLOCK_SUB': xblock_sub, 'dst_dtype': eval(f"tl.{dst_dtype}")}), - ] - - triton_kernel = create_triton_kernel(func_name, triton_func_op) - triton_kernel = triton.autotune( - configs=get_autotune_config(), - key=['numel'], - )(triton_kernel) - - def triton_func(x0, dst_dtype): - numel = x0.numel() - y0 = torch.empty((numel, ), dtype=dst_dtype, device=x0.device) - y1 = torch.ones((numel, ), dtype=dst_dtype, device=x0.device) - grid = lambda meta: (triton.cdiv(numel, meta['XBLOCK']), ) - triton_kernel[grid](x0, y1, y0, numel) - return y0 - - x0 = generate_tensor(shape, src_dtype).npu() - - # if func_name in ['sqrt', 'rsqrt', 'sqrt_rn', 'log', 'log2']: - # x0 = fill_negative_with_one(x0) - - torch_ref = torch_func(x0, eval(f"torch.{dst_dtype}")) - triton_cal = triton_func(x0, eval(f"torch.{dst_dtype}")) - validate_cmp(dst_dtype, triton_cal, torch_ref) diff --git a/third_party/ascend/examples/pytest_ut_regbase/test_clamp.py b/third_party/ascend/examples/pytest_ut_regbase/test_clamp.py deleted file mode 100644 index 9dd300d61..000000000 --- a/third_party/ascend/examples/pytest_ut_regbase/test_clamp.py +++ /dev/null @@ -1,63 +0,0 @@ -import triton -import triton.language as tl -import torch -import pytest -import atexit -from test_common import ( - generate_tensor, - validate_cmp, - fill_negative_with_one, - _float_dtypes_without_bf16 as _float_dtypes, - _int_dtypes, - _shape_1d, -) - - -@pytest.mark.parametrize('dtype', _float_dtypes + _int_dtypes) -@pytest.mark.parametrize('xblock_sub', _shape_1d) -@pytest.mark.parametrize('min', [1]) -def test_clamp(dtype, xblock_sub): - - xblock = triton.next_power_of_2(xblock_sub) - shape = (xblock, ) - - def torch_func(x0, x1, x2): - return torch.where(x0, x1, x2) - - def get_autotune_config(): - return [ - triton.Config({'XBLOCK': xblock, 'XBLOCK_SUB': xblock_sub}), - ] - - 
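Aside: every regbase test removed here uses the same launch wiring, namely a single pinned triton.Config, triton.autotune keyed on the element count, and a grid lambda that reads meta['XBLOCK']. A minimal sketch of that pattern, with a hypothetical copy kernel and assuming an attached Triton-capable device, is:

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({'XBLOCK': 1024, 'XBLOCK_SUB': 256})],  # one pinned config, as in these tests
    key=['numel'],
)
@triton.jit
def copy_kernel(in_ptr, out_ptr, numel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr):
    # Each program owns an XBLOCK-sized span and walks it in XBLOCK_SUB-sized sub-tiles.
    offset = tl.program_id(0) * XBLOCK
    base = tl.arange(0, XBLOCK_SUB)
    num_loop: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB
    for loop in range(num_loop):
        idx = offset + loop * XBLOCK_SUB + base
        msk = idx < numel
        x = tl.load(in_ptr + idx, mask=msk)
        tl.store(out_ptr + idx, x, mask=msk)


def run(x):
    y = torch.empty_like(x)
    numel = x.numel()
    # The grid only fixes the number of programs; XBLOCK itself comes from the selected config.
    grid = lambda meta: (triton.cdiv(numel, meta['XBLOCK']), )
    copy_kernel[grid](x, y, numel)
    return y

Pinning a single config keeps the runs deterministic in CI while still exercising the autotune code path.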
@triton.autotune( - configs=get_autotune_config(), - key=['numel'], - ) - @triton.jit - def triton_kernel(in_ptr0, in_ptr1, in_ptr2, out_ptr0, numel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base = tl.arange(0, XBLOCK_SUB) - num_loop: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB - for loop in range(num_loop): - idx = offset + loop * XBLOCK_SUB + base - msk = idx < numel - x0 = tl.load(in_ptr0 + idx, mask=msk) - x1 = tl.load(in_ptr1 + idx, mask=msk) - x2 = tl.load(in_ptr2 + idx, mask=msk) - y0 = tl.where(x0, x1, x2) - tl.store(out_ptr0 + idx, y0, mask=msk) - - def triton_func(x0, x1, x2): - numel = x0.numel() - y0 = torch.empty_like(x1) - grid = lambda meta: (triton.cdiv(numel, meta['XBLOCK']), ) - triton_kernel[grid](x0, x1, x2, y0, numel) - return y0 - - x0 = generate_tensor(shape, 'bool').npu() - x1 = generate_tensor(shape, dtype).npu() - x2 = generate_tensor(shape, dtype).npu() - - torch_ref = torch_func(x0, x1, x2) - triton_cal = triton_func(x0, x1, x2) - validate_cmp(dtype, triton_cal, torch_ref) diff --git a/third_party/ascend/examples/pytest_ut_regbase/test_common.py b/third_party/ascend/examples/pytest_ut_regbase/test_common.py deleted file mode 100644 index 6f17d04d0..000000000 --- a/third_party/ascend/examples/pytest_ut_regbase/test_common.py +++ /dev/null @@ -1,179 +0,0 @@ -import torch -import torch_npu -import pytest -import functools -import re - -_float_dtypes = ['float32', 'float16', 'bfloat16'] -_int_dtypes = ['int32', 'int64', 'int16', 'int8'] -_all_dtypes_no_bool = _float_dtypes + _int_dtypes -_all_dtypes = _all_dtypes_no_bool + ['bool'] -_32bit_dtypes = ['float32', 'int32'] -_16bit_dtypes = ['float16', 'bfloat16', 'int16'] -_float_dtypes_without_bf16 = ['float32', 'float16'] - -_shape_1d = [1, 3, 17, 32, 741] -_shape_5d = [ - (2, 2, 2, 2, 8), - (3, 1, 3, 5, 7), - (3, 7, 5, 3, 1), -] - - -def generate_tensor(shape, dtype): - if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': - return torch.randn(size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16': - return torch.randint(low=-2000, high=2000, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'int8': - return torch.randint(low=-128, high=127, size=shape, dtype=eval('torch.' + dtype)) - elif dtype == 'bool': - return torch.randint(low=0, high=2, size=shape).bool() - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def fill_zero_with_one(x): - return x.masked_fill(x == 0, 1) - - -def fill_negative_with_one(x): - return x.masked_fill(x < 0, 1) - - -def get_triton_sig_typename(dtype): - if dtype == 'float32': - tyname = "*fp32" - elif dtype == 'int32': - tyname = "*i32" - elif dtype == 'int64': - tyname = "*i64" - elif dtype == 'float16': - tyname = "*fp16" - elif dtype == 'int16': - tyname = "*i16" - elif dtype == 'int8': - tyname = "*i8" - elif dtype == 'bool': - tyname = "*i1" - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - return tyname - - -# Relative error: abs(x_ref - x_cal) / abs(x_ref) -# Absolute error: abs(x_ref - x_cal) - - -# calculation type operators require different error range -# It is a stricter verification and not satisfied now, save it here -def validate_cal(dtype, y_cal, y_ref): - if dtype == 'float16': - if torch.mean(y_ref) < 0.001: - assert torch.abs(y_cal - y_ref) < 0.001, "|y_cal - y_ref| < 0.001 is required !" 
- else: - diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.001 - # all true - assert diff.all(), "Relative error is less than 0.001 !" - if dtype == 'float32': - if torch.mean(y_ref) < 0.0001: - assert torch.abs(y_cal - y_ref) < 0.0001, "|y_cal - y_ref| < 0.0001 is required !" - else: - diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.0001 - assert diff.all(), "Relative error is less than 0.001 !" - elif dtype == 'bfloat16': - diff = torch.div(torch.abs(y_cal - y_ref), torch.abs(y_cal)) < 0.001 - assert diff.all(), "Relative error is less than 0.001 !" - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16': - assert torch.equal(y_cal, y_ref) - elif dtype == 'bool': - assert torch.equal(y_cal, y_ref) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -# moving and comparison ops require no precision error -def validate_cmp(dtype, y_cal, y_ref): - y_cal = y_cal.npu() - y_ref = y_ref.npu() - if dtype == 'float16': - torch.testing.assert_close(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) - elif dtype == 'bfloat16': - torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=1e-03, atol=1e-03, - equal_nan=True) - elif dtype == 'float32': - torch.testing.assert_close(y_ref, y_cal, rtol=1e-04, atol=1e-04, equal_nan=True) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8': - assert torch.equal(y_cal, y_ref) - elif dtype == 'bool': - assert torch.equal(y_cal, y_ref) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -def validate_cmp_with_expection(dtype, y_cal, y_ref, expect): - if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16': - if expect: - assert torch.allclose(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) - else: - assert not torch.allclose(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) - elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8': - if expect: - assert torch.equal(y_cal, y_ref) - else: - assert not torch.equal(y_cal, y_ref) - else: - raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) - - -# Use the following pytest fixture to run one test case by only single worker. 
-# Refer to https://pytest-xdist.readthedocs.io/en/stable/how-to.html#making-session-scoped-fixtures-execute-only-once -@pytest.fixture(scope="function") -def pytest_runonce(worker_id, request, cache): - if (cache.get(request.node.nodeid, "none")) == "none": - cache.set(request.node.nodeid, worker_id) - else: - file_name = f"pytest_{worker_id}.txt" - with open(file_name, 'a') as file: - file.write(f"{request.node.nodeid} is already processed by {worker_id}") - return True - yield True - cache.set(request.node.nodeid, "none") - - -def raises_with_match(expected_exception, match_pattern): - - def decorator(test_func): - - @functools.wraps(test_func) - def wrapper(*args, **kwargs): - with pytest.raises(expected_exception, match=match_pattern): - return test_func(*args, **kwargs) - - return wrapper - - return decorator - - -def capture_output(expected_output): - - def decorator(test_func): - - @functools.wraps(test_func) - def wrapper(*args, **kwargs): - capsys = kwargs.pop('capsys', None) - if capsys is None: - try: - capsys = pytest.fixture(capsys)() - except: - raise RuntimeError("This decorator requires pytest's capsys fixture") - test_func(capsys, *args, **kwargs) - captured = capsys.readouterr() - # pybind11::scoped_ostream_redirect captures std::cout with \x00 inserted - # for now, no idea how to eliminate \x00 from C++ side. - cleaned = re.sub(r"\x00", "", captured.out) - assert expected_output in cleaned - - return wrapper - - return decorator diff --git a/third_party/ascend/examples/pytest_ut_regbase/test_fma.py b/third_party/ascend/examples/pytest_ut_regbase/test_fma.py deleted file mode 100644 index c8e62826f..000000000 --- a/third_party/ascend/examples/pytest_ut_regbase/test_fma.py +++ /dev/null @@ -1,65 +0,0 @@ -import triton -import triton.language as tl -import torch -import pytest -from test_common import ( - generate_tensor, - validate_cmp, - _float_dtypes_without_bf16 as _float_dtypes, - _int_dtypes, - _shape_1d, -) - - -# @pytest.mark.parametrize('dtype', ['float32']) -@pytest.mark.parametrize('xblock_sub', [32]) -@pytest.mark.parametrize('dtype', _float_dtypes) -# @pytest.mark.parametrize('xblock_sub', _shape_1d) -def test_fma(dtype, xblock_sub): - - xblock = triton.next_power_of_2(xblock_sub) - shape = (xblock, ) - - def torch_func(x0, x1, x2): - return x0 * x1 + x2 - - def get_autotune_config(): - return [ - triton.Config({'XBLOCK': xblock, 'XBLOCK_SUB': xblock_sub}), - ] - - @triton.autotune( - configs=get_autotune_config(), - key=['numel'], - ) - @triton.jit - def triton_kernel(in_ptr0, in_ptr1, in_ptr2, out_ptr0, numel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base = tl.arange(0, XBLOCK_SUB) - num_loop: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB - for loop in range(num_loop): - idx = offset + loop * XBLOCK_SUB + base - msk = idx < numel - x0 = tl.load(in_ptr0 + idx, mask=msk) - x1 = tl.load(in_ptr1 + idx, mask=msk) - x2 = tl.load(in_ptr2 + idx, mask=msk) - y0 = tl.fma(x0, x1, x2) - tl.store(out_ptr0 + idx, y0, mask=msk) - - def triton_func(x0, x1, x2): - numel = x0.numel() - y0 = torch.empty_like(x0) - grid = lambda meta: (triton.cdiv(numel, meta['XBLOCK']), ) - triton_kernel[grid](x0, x1, x2, y0, numel) - return y0 - - x0 = generate_tensor(shape, dtype).npu() - x1 = generate_tensor(shape, dtype).npu() - x2 = generate_tensor(shape, dtype).npu() - - torch_ref = torch_func(x0, x1, x2) - triton_cal = triton_func(x0, x1, x2) - validate_cmp(dtype, triton_cal, torch_ref) - - -# test_fma('float32', 32) diff --git 
a/third_party/ascend/examples/pytest_ut_regbase/test_pure_simt.py b/third_party/ascend/examples/pytest_ut_regbase/test_pure_simt.py deleted file mode 100644 index 317b34bf6..000000000 --- a/third_party/ascend/examples/pytest_ut_regbase/test_pure_simt.py +++ /dev/null @@ -1,87 +0,0 @@ -import triton -import triton.language as tl -import numpy as np -import torch -import pytest -import test_common -import math - - -def torch_pointwise(x0, x1): - res = x0 + x1 - return res - - -@triton.jit -def add_kernel(x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. - # NOTE: `constexpr` so it can be used as a shape value. - ): - for xoffset in range(0, n_elements, BLOCK_SIZE): - offsets = xoffset + tl.arange(0, BLOCK_SIZE)[:, None] - # Create a mask to guard memory operations against out-of-bounds accesses. - mask = offsets < n_elements - # Load x and y from DRAM, masking out any extra elements in case the input is not a - # multiple of the block size. - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - # Write x + y back to DRAM. - tl.store(output_ptr + offsets, output, mask=mask) - - -@triton.jit -def add_kernel_any_grid(x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. - # NOTE: `constexpr` so it can be used as a shape value. - ): - pid_id = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - - -@pytest.mark.parametrize('param_list', [ - ['float32', (98432, ), 1024], - ['float16', (98432, ), 1024], -]) -def test_case(param_list, monkeypatch): - monkeypatch.setenv("TRITON_DISABLE_FFTS", "1") - dtype, shape, xblock = param_list - x0 = test_common.generate_tensor(shape, dtype).npu() - x1 = test_common.generate_tensor(shape, dtype).npu() - n_elements = math.prod(shape) - y_ref = torch_pointwise(x0, x1) - y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() - # TODO: support grid > 1 after we pass grid size as input arguments - add_kernel[(1, )](x0, x1, y_cal, n_elements, BLOCK_SIZE=xblock, force_simt_only=True) - test_common.validate_cmp(dtype, y_cal, y_ref) - monkeypatch.delenv("TRITON_DISABLE_FFTS") - - -@pytest.mark.parametrize('param_list', [ - ['float32', (98432, ), 1024], - ['float16', (98432, ), 1024], -]) -def test_any_grid(param_list, monkeypatch): - monkeypatch.setenv("TRITON_DISABLE_FFTS", "1") - dtype, shape, xblock = param_list - x0 = test_common.generate_tensor(shape, dtype).npu() - x1 = test_common.generate_tensor(shape, dtype).npu() - n_elements = math.prod(shape) - y_ref = torch_pointwise(x0, x1) - y_cal = torch.zeros(shape, dtype=eval('torch.' 
+ dtype)).npu() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) - add_kernel[grid](x0, x1, y_cal, n_elements, BLOCK_SIZE=xblock, force_simt_only=True) - test_common.validate_cmp(dtype, y_cal, y_ref) - monkeypatch.delenv("TRITON_DISABLE_FFTS") diff --git a/third_party/ascend/examples/pytest_ut_regbase/test_reduce.py b/third_party/ascend/examples/pytest_ut_regbase/test_reduce.py deleted file mode 100644 index 5d627ca9a..000000000 --- a/third_party/ascend/examples/pytest_ut_regbase/test_reduce.py +++ /dev/null @@ -1,148 +0,0 @@ -import triton -import triton.language as tl -import torch -import pytest -import atexit -from test_common import ( - generate_tensor, - validate_cmp, - fill_zero_with_one, - _float_dtypes_without_bf16 as _float_dtypes, - _int_dtypes, - _shape_1d, -) - -################################################# -FUNCTIONS_TO_TEST = { - 'sum': ("tl.sum(x0, 1) + tl.sum(x1, 1)", "torch.sum(x0, 1) + torch.sum(x1, 1)"), - 'max': ("tl.max(x0, 1) + tl.max(x1, 1)", "( torch.max(x0.cpu(), 1)[0] + torch.max(x1.cpu(), 1)[0] ).npu()"), - 'min': ("tl.min(x0, 1) + tl.min(x1, 1)", "( torch.min(x0.cpu(), 1)[0] + torch.min(x1.cpu(), 1)[0] ).npu()"), - 'argmax': ("tl.argmax(x0, 1) + tl.argmax(x1, 1)", "torch.argmax(x0, 1) + torch.argmax(x1, 1)"), - 'argmin': ("tl.argmin(x0, 1) + tl.argmin(x1, 1)", "torch.argmin(x0, 1) + torch.argmin(x1, 1)"), -} -################################################# - -# Global dictionary to keep track of temporary files -_temp_kernel_files = {} - - -def _cleanup_temp_files(): - import os - for file_path in _temp_kernel_files.values(): - try: - if os.path.exists(file_path): - os.unlink(file_path) - except: - pass - - -atexit.register(_cleanup_temp_files) - - -def create_triton_kernel(func_name, func_pattern): - import tempfile - import os - kernel_source = f""" -import triton -import triton.language as tl - -@triton.jit -def triton_kernel(in_ptr0, in_ptr1, out_ptr0, xnumel, rnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr, RBLOCK: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - xbase = tl.arange(0, XBLOCK_SUB) - ridx = tl.arange(0, RBLOCK) - num_loop: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB - for loop in range(num_loop): - xidx = offset + loop * XBLOCK_SUB + xbase - idx = xidx[:, None] * rnumel + ridx[None, :] - xmsk = xidx < xnumel - rmsk = ridx < xnumel - msk = xmsk[:, None] & rmsk[None, :] - x0 = tl.load(in_ptr0 + idx, mask=msk) - x1 = tl.load(in_ptr1 + idx, mask=msk) - y0 = {func_pattern} - tl.store(out_ptr0 + xidx, y0, mask=xmsk) -""" - - # Create a temporary file with a unique name based on the function name - if func_name in _temp_kernel_files: - temp_file_path = _temp_kernel_files[func_name] - else: - fd, temp_file_path = tempfile.mkstemp(suffix='.py', prefix=f'triton_kernel_{func_name}_') - os.close(fd) # We don't need the file descriptor - _temp_kernel_files[func_name] = temp_file_path - - # Write the kernel source to the file - with open(temp_file_path, 'w') as f: - f.write(kernel_source) - - # Import the kernel from the temporary file - import importlib.util - module_name = f"triton_kernel_{func_name.replace('.', '_')}" - spec = importlib.util.spec_from_file_location(module_name, temp_file_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module.triton_kernel - - -@pytest.mark.parametrize('dtype', _float_dtypes + _int_dtypes) -@pytest.mark.parametrize('xblock_sub', [5]) -@pytest.mark.parametrize('rblock', [128]) -@pytest.mark.parametrize('func_name', 
FUNCTIONS_TO_TEST.keys()) -def test_reduce(dtype, xblock_sub, rblock, func_name): - - # if dtype in _int_dtypes: - # skip_int_dtype_ops = [ - # 'div_rn', 'fdiv', - # ] - # if func_name in skip_int_dtype_ops: - # pytest.skip(f"{func_name} only tested with float dtypes") - # if dtype in _float_dtypes: - # skip_float_dtype_ops = [ - # 'mod', 'floordiv', 'cdiv', 'and', 'or', 'xor', - # ] - # if func_name in skip_float_dtype_ops: - # pytest.skip(f"{func_name} only tested with int dtypes") - # if func_name == 'mod' and dtype == 'int64': - # pytest.skip(f"{func_name} skips int64") - # if func_name in ['logical_and', 'logical_or'] and dtype != 'bool': - # pytest.skip(f"{func_name} tests only bool") - - xblock = triton.next_power_of_2(xblock_sub) - shape = (xblock, rblock) - triton_func_op, torch_func_op = FUNCTIONS_TO_TEST[func_name] - - def torch_func(x0, x1): - return eval(torch_func_op) - - def get_autotune_config(): - return [ - triton.Config({'XBLOCK': xblock, 'XBLOCK_SUB': xblock_sub, 'RBLOCK': rblock}), - ] - - triton_kernel = create_triton_kernel(func_name, triton_func_op) - triton_kernel = triton.autotune( - configs=get_autotune_config(), - key=['xnumel'], - )(triton_kernel) - - def triton_func(x0, x1, out_dtype): - xnumel, rnumel = x0.shape - y0 = torch.empty((xnumel, ), dtype=x0.dtype).npu() - grid = lambda meta: (triton.cdiv(xnumel, meta['XBLOCK']), ) - triton_kernel[grid](x0, x1, y0, xnumel, rnumel) - return y0 - - x0 = generate_tensor(shape, dtype).npu() - x1 = generate_tensor(shape, dtype).npu() - - # if func_name in ['div', 'fdiv', 'div_rn', 'cdiv', 'floordiv', 'mod']: - # x1 = fill_zero_with_one(x1) - - out_dtype = x0.dtype - if func_name in ['argmax', 'argmin']: - out_dtype = torch.int32 - - triton_cal = triton_func(x0, x1, out_dtype) - torch_ref = torch_func(x0, x1) - validate_cmp(out_dtype.__str__().split('.')[1], triton_cal, torch_ref) diff --git a/third_party/ascend/examples/pytest_ut_regbase/test_unary.py b/third_party/ascend/examples/pytest_ut_regbase/test_unary.py deleted file mode 100644 index 87e61bd4d..000000000 --- a/third_party/ascend/examples/pytest_ut_regbase/test_unary.py +++ /dev/null @@ -1,163 +0,0 @@ -import triton -import triton.language as tl -import torch -import pytest -import atexit -from test_common import ( - generate_tensor, - validate_cmp, - fill_negative_with_one, - _float_dtypes_without_bf16 as _float_dtypes, - _int_dtypes, - _shape_1d, -) - -################################################# -FUNCTIONS_TO_TEST = { - 'abs': ("tl.abs(x0)", "torch.abs(x0)"), - 'exp': ("tl.exp(x0)", "torch.exp(x0)"), - 'exp2': ("tl.exp2(x0)", "torch.exp2(x0)"), - 'log': ("tl.log(x0)", "torch.log(x0)"), - 'log2': ("tl.log2(x0)", "torch.log2(x0)"), - 'sin': ("tl.sin(x0)", "torch.sin(x0)"), - 'cos': ("tl.cos(x0)", "torch.cos(x0)"), - 'sqrt': ("tl.sqrt(x0)", "torch.sqrt(x0)"), - 'rsqrt': ("tl.rsqrt(x0)", "torch.rsqrt(x0)"), - 'sigmoid': ("tl.sigmoid(x0)", "torch.sigmoid(x0)"), - 'sqrt_rn': ("tl.sqrt_rn(x0)", "torch.sqrt(x0)"), - 'erf': ("tl.erf(x0)", "torch.erf(x0)"), - 'neg': ("-x0", "-x0"), - 'not': ("not(x0)", "torch.bitwise_not(x0)"), - 'invert': - ("~x0", "( ~( x0.to(torch.int32) ) ).to(x0.dtype) if x0.dtype in [torch.float32, torch.float16] else ~x0"), - 'ceil': ("tl.ceil(x0)", "torch.ceil(x0)"), - 'floor': ("tl.floor(x0)", "torch.floor(x0)"), - 'lshift': ("x0 << 2", "x0 << 2"), - 'rshift': ("x0 >> 2", "x0 >> 2"), -} -################################################# - -# Global dictionary to keep track of temporary files -_temp_kernel_files = {} - - -def 
_cleanup_temp_files(): - import os - for file_path in _temp_kernel_files.values(): - try: - if os.path.exists(file_path): - os.unlink(file_path) - except: - pass - - -atexit.register(_cleanup_temp_files) - - -def create_triton_kernel(func_name, func_pattern): - import tempfile - import os - kernel_source = f""" -import triton -import triton.language as tl - -@triton.jit -def triton_kernel(in_ptr0, out_ptr0, numel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base = tl.arange(0, XBLOCK_SUB) - num_loop: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB - for loop in range(num_loop): - idx = offset + loop * XBLOCK_SUB + base - msk = idx < numel - x0 = tl.load(in_ptr0 + idx, mask=msk) - y0 = {func_pattern} - tl.store(out_ptr0 + idx, y0, mask=msk) -""" - - # Create a temporary file with a unique name based on the function name - if func_name in _temp_kernel_files: - temp_file_path = _temp_kernel_files[func_name] - else: - fd, temp_file_path = tempfile.mkstemp(suffix='.py', prefix=f'triton_kernel_{func_name}_') - os.close(fd) # We don't need the file descriptor - _temp_kernel_files[func_name] = temp_file_path - - # Write the kernel source to the file - with open(temp_file_path, 'w') as f: - f.write(kernel_source) - - # Import the kernel from the temporary file - import importlib.util - module_name = f"triton_kernel_{func_name.replace('.', '_')}" - spec = importlib.util.spec_from_file_location(module_name, temp_file_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module.triton_kernel - - -@pytest.mark.parametrize('dtype', _float_dtypes + _int_dtypes) -@pytest.mark.parametrize('xblock_sub', _shape_1d) -@pytest.mark.parametrize('func_name', FUNCTIONS_TO_TEST.keys()) -def test_unary(dtype, xblock_sub, func_name): - - if dtype in _int_dtypes: - skip_int_dtype_ops = [ - 'exp', - 'exp2', - 'log', - 'log2', - 'sin', - 'cos', - 'sqrt', - 'rsqrt', - 'sqrt_rn', - 'sigmoid', - 'erf', - 'ceil', - 'floor', - ] - if func_name in skip_int_dtype_ops: - pytest.skip(f"{func_name} only tested with float dtypes") - if dtype in _float_dtypes: - skip_float_dtype_ops = [ - 'not', - 'invert', - 'lshift', - 'rshift', - ] - if func_name in skip_float_dtype_ops: - pytest.skip(f"{func_name} only tested with int dtypes") - - xblock = triton.next_power_of_2(xblock_sub) - shape = (xblock, ) - triton_func_op, torch_func_op = FUNCTIONS_TO_TEST[func_name] - - def torch_func(x0): - return eval(torch_func_op) - - def get_autotune_config(): - return [ - triton.Config({'XBLOCK': xblock, 'XBLOCK_SUB': xblock_sub}), - ] - - triton_kernel = create_triton_kernel(func_name, triton_func_op) - triton_kernel = triton.autotune( - configs=get_autotune_config(), - key=['numel'], - )(triton_kernel) - - def triton_func(x0): - y0 = torch.empty_like(x0) - numel = x0.numel() - grid = lambda meta: (triton.cdiv(numel, meta['XBLOCK']), ) - triton_kernel[grid](x0, y0, numel) - return y0 - - x0 = generate_tensor(shape, dtype).npu() - - if func_name in ['sqrt', 'rsqrt', 'sqrt_rn', 'log', 'log2']: - x0 = fill_negative_with_one(x0) - - torch_ref = torch_func(x0) - triton_cal = triton_func(x0) - validate_cmp(dtype, triton_cal, torch_ref) diff --git a/third_party/ascend/examples/pytest_ut_regbase/test_where.py b/third_party/ascend/examples/pytest_ut_regbase/test_where.py deleted file mode 100644 index 9be852bb3..000000000 --- a/third_party/ascend/examples/pytest_ut_regbase/test_where.py +++ /dev/null @@ -1,62 +0,0 @@ -import triton -import triton.language as tl 
-import torch -import pytest -import atexit -from test_common import ( - generate_tensor, - validate_cmp, - fill_negative_with_one, - _float_dtypes_without_bf16 as _float_dtypes, - _int_dtypes, - _shape_1d, -) - - -@pytest.mark.parametrize('dtype', _float_dtypes + _int_dtypes) -@pytest.mark.parametrize('xblock_sub', _shape_1d) -def test_where(dtype, xblock_sub): - - xblock = triton.next_power_of_2(xblock_sub) - shape = (xblock, ) - - def torch_func(x0, x1, x2): - return torch.where(x0, x1, x2) - - def get_autotune_config(): - return [ - triton.Config({'XBLOCK': xblock, 'XBLOCK_SUB': xblock_sub}), - ] - - @triton.autotune( - configs=get_autotune_config(), - key=['numel'], - ) - @triton.jit - def triton_kernel(in_ptr0, in_ptr1, in_ptr2, out_ptr0, numel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): - offset = tl.program_id(0) * XBLOCK - base = tl.arange(0, XBLOCK_SUB) - num_loop: tl.constexpr = (XBLOCK + XBLOCK_SUB - 1) // XBLOCK_SUB - for loop in range(num_loop): - idx = offset + loop * XBLOCK_SUB + base - msk = idx < numel - x0 = tl.load(in_ptr0 + idx, mask=msk) - x1 = tl.load(in_ptr1 + idx, mask=msk) - x2 = tl.load(in_ptr2 + idx, mask=msk) - y0 = tl.where(x0, x1, x2) - tl.store(out_ptr0 + idx, y0, mask=msk) - - def triton_func(x0, x1, x2): - numel = x0.numel() - y0 = torch.empty_like(x1) - grid = lambda meta: (triton.cdiv(numel, meta['XBLOCK']), ) - triton_kernel[grid](x0, x1, x2, y0, numel) - return y0 - - x0 = generate_tensor(shape, 'bool').npu() - x1 = generate_tensor(shape, dtype).npu() - x2 = generate_tensor(shape, dtype).npu() - - torch_ref = torch_func(x0, x1, x2) - triton_cal = triton_func(x0, x1, x2) - validate_cmp(dtype, triton_cal, torch_ref) diff --git a/third_party/ascend/examples/run_daily.sh b/third_party/ascend/examples/run_daily.sh deleted file mode 100755 index 3ce2497a5..000000000 --- a/third_party/ascend/examples/run_daily.sh +++ /dev/null @@ -1,197 +0,0 @@ -#!/bin/bash - -set -ex - -script=$(readlink -f "$0") -script_dir=$(dirname "$script") - -source /usr/local/CANN_8.2.RC1.alpha002/ascend-toolkit/set_env.sh -export LLVM_BUILD_DIR=/opt/llvm-b5cc222 - -COMPILER_ROOT=/home/shared/bisheng_toolkit_20250922 -BSIR_COMPILE_PATH=$(find "$COMPILER_ROOT" -name "bishengir-compile" | xargs dirname) -export PATH=${COMPILER_ROOT}:${BSIR_COMPILE_PATH}:$PATH -# FIXME: the 20250812 bishengir-compile requires the pairing bisheng compiler -export BISHENG_INSTALL_PATH=/home/shared/cann_compiler_20250812/compiler/ccec_compiler/bin - -# 新增:定义统计文件路径 -SUMMARY_FILE="${WORKSPACE}/ascend/examples/summary.txt" - -function build_triton() { - cd ${WORKSPACE} - pip3 uninstall triton_ascend -y - - git submodule set-url third_party/triton https://gitee.com/shijingchang/triton.git - git submodule sync && git submodule update --init --recursive - - bash scripts/build.sh ${WORKSPACE}/ascend ${LLVM_BUILD_DIR} 3.2.0 install 0 -} - -function run_case_by_multi_card() { - NPU_DEVICES=$(ls /dev/davinci? 2>/dev/null | wc -l) - [ $NPU_DEVICES -eq 0 ] && { - echo "No Ascend devices found!" 
- exit 1 - } - - echo "Detected $NPU_DEVICES Ascend devices" - - if [ -d ${WORKSPACE}triton ];then - rm -rf ${WORKSPACE}triton - fi - - if [ -d ~/.triton/dump ];then - rm -rf ~/.triton/dump - fi - - if [ -d ~/.triton/cache ];then - rm -rf ~/.triton/cache - fi - - test_dir=$1 - cd ${test_dir} - - # 清理旧日志 - rm -rf logs && mkdir logs - - # 记录测试开始时间 - start_time=$(date +"%Y-%m-%d %H:%M:%S") - echo "===== 测试开始时间: ${start_time} =====" - - # 运行测试并捕获退出状态 - set +e - pytest ${test_dir} -n auto --dist=loadfile -v --junitxml=logs/results.xml | tee logs/raw_output.log - pytest_exit=$? - set -e - - # 处理日志(添加设备标签) - awk ' - />> Worker gw[0-9]+ using NPU device/ { - split($0, parts, / /) - dev_id = parts[NF] - worker = parts[3] - print "[" strftime("%Y-%m-%d %H:%M:%S") "| DEV-" dev_id "] " $0 - next - } - { print "[" strftime("%Y-%m-%d %H:%M:%S") "| DEV-" dev_id "] " $0 } - ' logs/raw_output.log > logs/combined.log - - # 新增:解析测试结果统计 - total_tests=0 - passed_tests=0 - failed_tests=0 - skipped_tests=0 - error_tests=0 - - # 使用Python解析JUnit XML报告 - python3 -c " -import xml.etree.ElementTree as ET -import os - -xml_file = os.path.join('logs', 'results.xml') -if not os.path.exists(xml_file): - print('JUnitXML report not found:', xml_file) - exit(1) - -tree = ET.parse(xml_file) -root = tree.getroot() - -total = 0 -passed = 0 -failed = 0 -skipped = 0 -errors = 0 - -# 遍历所有testsuite -for testsuite in root.findall('testsuite'): - total += int(testsuite.get('tests', 0)) - passed += int(testsuite.get('tests', 0)) - int(testsuite.get('errors', 0)) - int(testsuite.get('failures', 0)) - int(testsuite.get('skipped', 0)) - failed += int(testsuite.get('failures', 0)) - skipped += int(testsuite.get('skipped', 0)) - errors += int(testsuite.get('errors', 0)) - -print(f'total_tests={total}') -print(f'passed_tests={passed}') -print(f'failed_tests={failed}') -print(f'skipped_tests={skipped}') -print(f'error_tests={errors}') -" > logs/stats.tmp - - # 加载统计结果 - source logs/stats.tmp - rm logs/stats.tmp - - # 记录测试结束时间 - end_time=$(date +"%Y-%m-%d %H:%M:%S") - duration=$(( $(date -d "$end_time" +%s) - $(date -d "$start_time" +%s) )) - duration_str=$(printf "%02dh %02dm %02ds" $((duration/3600)) $(((duration%3600)/60)) $((duration%60))) - - # 新增:生成统计摘要 - stats_summary=" -===== generalization_cases测试统计摘要 ===== -测试目录: $(basename ${test_dir}) -测试开始时间: ${start_time} -测试结束时间: ${end_time} -总耗时: ${duration_str} ------------------------- -总用例数: ${total_tests} -成功用例: ${passed_tests} -失败用例: ${failed_tests} -跳过用例: ${skipped_tests} -错误用例: ${error_tests} -成功率: $(( passed_tests * 100 / total_tests ))% (成功/总数) -设备数量: ${NPU_DEVICES} -======================== -" - - # 输出统计信息到控制台 - echo "${stats_summary}" - - # 追加统计信息到summary.txt - echo "${stats_summary}" >> ${SUMMARY_FILE} - - echo "========================================" - echo "All tests completed!" 
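For reference, the tallying arithmetic that run_daily.sh embeds in its inline python3 -c block (passed = tests - errors - failures - skipped, summed over every <testsuite> element of the pytest junitxml report) can be written as a small standalone helper. This is only an illustrative sketch; logs/results.xml is the path the deleted script writes:

import xml.etree.ElementTree as ET


def tally_junit(xml_path="logs/results.xml"):
    # Sum the counters reported by each <testsuite> element of a pytest junitxml report.
    root = ET.parse(xml_path).getroot()
    totals = {"tests": 0, "failures": 0, "skipped": 0, "errors": 0}
    for suite in root.iter("testsuite"):
        for key in totals:
            totals[key] += int(suite.get(key, 0))
    # Passed cases are whatever remains after failures, skips and errors.
    totals["passed"] = totals["tests"] - totals["failures"] - totals["skipped"] - totals["errors"]
    return totals


if __name__ == "__main__":
    print(tally_junit())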
- echo "JUnit Report: logs/results.xml" - echo "Combined Log: logs/combined.log" - echo "统计摘要已追加到: ${SUMMARY_FILE}" - echo "========================================" - - zip_file=$2 - cd ${test_dir}/logs - zip ${zip_file} combined.log - cp ${zip_file} "/home/daily_log" - - # 返回pytest的退出状态 - return $pytest_exit -} - -# build in torch 2.6.0 -source /opt/miniconda3/bin/activate torch_260 -build_triton - -cd ${WORKSPACE} - -# 初始化统计文件 -echo "生成时间: $(date +"%Y-%m-%d %H:%M:%S")" >> ${SUMMARY_FILE} -echo "========================================" >> ${SUMMARY_FILE} - -# run inductor cases -TEST_inductor_cases="${WORKSPACE}/ascend/examples/inductor_cases" -cd ${TEST_inductor_cases} -bash run_inductor_test.sh - -# run gene case -zip_file="test_generalizetion_case_$(date +%Y%m%d).zip" -TEST_generalization="${WORKSPACE}/ascend/examples/generalization_cases" -run_case_by_multi_card ${TEST_generalization} ${zip_file} - -echo "========================================" >> ${SUMMARY_FILE} - -# run flaggems cases -TEST_flaggems_cases="${WORKSPACE}/ascend/examples/flaggems_cases" -cd ${TEST_flaggems_cases} -bash run_flaggems_test.sh - -# copy summary.txt to /home/daily_log -cp ${SUMMARY_FILE} /home/daily_log diff --git a/third_party/ascend/examples/run_test.sh b/third_party/ascend/examples/run_test.sh deleted file mode 100755 index baa5ee860..000000000 --- a/third_party/ascend/examples/run_test.sh +++ /dev/null @@ -1,190 +0,0 @@ -#!/bin/bash - -set -ex - -script=$(readlink -f "$0") -script_dir=$(dirname "$script") - -# skiped script -skip_script=("bench_utils.py" "11-rab_time.py") - -function uninstall_triton_ascend() { - set +e - while true; do - pip3 uninstall triton_ascend -y | grep "Found existing installation" - if [ $? -eq 1 ]; then - echo "All triton_ascend versions are uninstalled" - break - fi - done - set -e -} - -function build_triton() { - - cd ${WORKSPACE} - # Run uninstall once because the while-loop does not stop. No idea why. - # uninstall_triton_ascend - pip3 uninstall triton_ascend -y - - git submodule set-url third_party/triton https://gitee.com/shijingchang/triton.git - git submodule sync && git submodule update --init --recursive - - bash scripts/build.sh ${WORKSPACE}/ascend ${LLVM_BUILD_DIR} 3.2.0 install 0 -} - -function run_pytestcases() { - if [ -d ${HOME}/.triton/dump ]; then - rm -rf ${HOME}/.triton/dump - fi - if [ -d ${HOME}/.triton/cache ]; then - rm -rf ${HOME}/.triton/cache - fi - - cd ${script_dir} - TARGET_DIR="$1" - cd ${TARGET_DIR} - pytest -n 16 --dist=load . || { exit 1 ; } - -} - -function run_pythoncases() { - if [ -d ${HOME}/.triton/dump ]; then - rm -rf ${HOME}/.triton/dump - fi - if [ -d ${HOME}/.triton/cache ]; then - rm -rf ${HOME}/.triton/cache - fi - - cd ${script_dir} - TARGET_DIR="$1" - cd ${TARGET_DIR} - - declare -a pids - declare -A status_map - has_failure=0 - - # 查找并运行所有.py文件 - for test_script in *.py; do - for skip_item in "${skip_script[@]}"; do - if [ "$test_script" == "$skip_item" ]; then - break - fi - done - - if [ -f "$test_script" ]; then - echo "启动测试: $test_script" - python "./$test_script" & - pid=$! - pids+=($pid) - status_map[$pid]=$test_script - fi - done - - # 等待所有后台进程完成并检查状态 - for pid in "${pids[@]}"; do - wait "$pid" - exit_status=$? - script_name=${status_map[$pid]} - - if [ $exit_status -ne 0 ]; then - echo "[失败] $script_name - 退出码 $exit_status" - has_failure=1 - else - echo "[成功] $script_name" - fi - done - - echo "--------------------------------" - - # 根据测试结果退出 - if [ $has_failure -eq 1 ]; then - echo "部分测试失败!" 
- exit 1 - else - echo "所有测试通过!" - exit 0 - fi -} - -function validate_git_commit_title() { - if [ $# -lt 1 ]; then - echo "Usage: $0 " - exit 1 - fi - commit_title=$1 - if ! echo "${commit_title}" | grep -qE "^(feat|fix|docs|style|refactor|test|chore|revert)(\(.*\))?: .+"; then - echo "❌ The git commit title does not comply with the specifications!" - echo "Format Requirements: (): " - echo "e.g.: feat(user): The login function is added." - echo "Allowed Types: feat | fix | docs | style | refactor | test | chore | revert" - exit 1 - fi - echo "✅ The submitted information complies with the specifications." -} - -function validate_pr_all_commits_title() { - commit_titles=$(git log master..HEAD --oneline | sed 's/^[^ ]* //') - if [ -z "$commit_titles" ]; then - echo "No commits found between HEAD and master." - exit 1 - fi - echo "Validating commit titles..." - echo "----------------------------" - while IFS= read -r title; do - echo "Checking: $title" - if ! validate_git_commit_title "$title" 2>/dev/null; then - echo "Error in commit: $title" >&2 - HAS_ERROR=true - fi - done <<< "$commit_titles" - if [ "$HAS_ERROR" = true ]; then - echo "----------------------------" - echo "❌ Some commit titles do not meet the specifications." >&2 - exit 1 - else - echo "----------------------------" - echo "✅ All commit titles meet the specifications." - fi -} - -# if ! validate_pr_all_commits_title 2>/dev/null; then -# exit 1 -# fi - -source /usr/local/CANN_8.2.RC1.alpha002/ascend-toolkit/set_env.sh -export LLVM_BUILD_DIR=/opt/llvm-b5cc222 - -# FIXME: 20250508 the bishengir-compile in the CANN 8.0.T115 fails lots of cases -# So we need to use another version of compiler. -COMPILER_ROOT=/home/shared/bisheng_toolkit_20250922 -BSIR_COMPILE_PATH=$(find "$COMPILER_ROOT" -name "bishengir-compile" | xargs dirname) -export PATH=${COMPILER_ROOT}:${BSIR_COMPILE_PATH}:$PATH -# FIXME: the 20250812 bishengir-compile requires the pairing bisheng compiler -export BISHENG_INSTALL_PATH=/home/shared/cann_compiler_20250812/compiler/ccec_compiler/bin - -# build in torch 2.6.0 -source /opt/miniconda3/bin/activate torch_260 -build_triton - -echo "Run ttir to linalg tests..." -cd ${WORKSPACE}/build/cmake.linux-aarch64-cpython-3.11 -ninja check-triton-adapter-lit-tests -if [ $? 
-eq 0 ]; then - echo "All ttir to linalg tests passed" -else - echo "Some ttir to linalg tests failed" - exit 1 -fi - -pytestcase_dir=("pytest_ut") -for test_dir in "${pytestcase_dir[@]}"; do - echo "run pytestcase in ${test_dir}" - run_pytestcases ${test_dir} -done - -pythoncase_dir=("autotune_cases" "benchmark_cases" "tutorials") -for test_dir in "${pythoncase_dir[@]}"; do - echo "run pythoncase in ${test_dir}" - run_pythoncases ${test_dir} -done diff --git a/third_party/ascend/examples/simt_perf_cases/test_embedding_sum.py b/third_party/ascend/examples/simt_perf_cases/test_embedding_sum.py deleted file mode 100644 index b6e226600..000000000 --- a/third_party/ascend/examples/simt_perf_cases/test_embedding_sum.py +++ /dev/null @@ -1,105 +0,0 @@ -import triton -import triton.language as tl -import torch -import torch_npu -import pytest - - -@triton.jit -def triton_unk_fused_embedding_sum_5( - in_ptr0, - in_ptr1, - out_ptr0, - y0_numel, - x1_numel, - r2_numel, - X1BLOCK_SUB: tl.constexpr, - R2BLOCK_SUB: tl.constexpr, -): - y0_offset = tl.program_id(0) - grid_size = tl.num_programs(0) - base_x1 = tl.arange(0, X1BLOCK_SUB) - base_r2 = tl.arange(0, R2BLOCK_SUB) - loops_r2 = (r2_numel + R2BLOCK_SUB - 1) // R2BLOCK_SUB - for y0 in range(y0_offset, y0_numel, grid_size): - x1 = base_x1[None, :] - x1_store = base_x1 - _tmp8 = tl.full([R2BLOCK_SUB, X1BLOCK_SUB], 0, tl.float32) - for loop_r2 in range(loops_r2): - r2 = loop_r2 + base_r2[:, None] * loops_r2 - r2_mask = r2 < r2_numel - tmp0 = tl.load(in_ptr0 + (r2 + r2_numel * y0), r2_mask, other=0.0) - #embedding table length constant fold - tmp1 = tl.full([R2BLOCK_SUB, X1BLOCK_SUB], 9000, tl.int64) - tmp2 = tmp0 + tmp1 - tmp3 = tmp0 < 0 - tmp4 = tl.where(tmp3, tmp2, tmp0) - tmp6 = tl.load(in_ptr1 + (x1 + x1_numel * tmp4), r2_mask, other=0.0) - tmp7 = tl.reshape(tmp6, [R2BLOCK_SUB, X1BLOCK_SUB]) - tmp9 = _tmp8 + tmp7 - _tmp8 = tl.where(r2_mask, tmp9, _tmp8) - tmp8 = tl.sum(_tmp8, 0) - tl.store(out_ptr0 + (x1_store + x1_numel * y0), tmp8) - - -def torch_red_fused_embedding_sum_vec(arg0, arg1): - adjusted = torch.where(arg0 < 0, arg0 + 9000, arg0) - emb = arg1[adjusted] - return emb.sum(dim=1) - - -@pytest.mark.parametrize( - "param_list", - [ - [128, 128, 4000], - ], -) -def test_embedding_sum(param_list): - y0_numel, x1_numel, r2_numel = param_list - arg0_1 = torch.randint(-1000, 9000, (y0_numel, r2_numel), dtype=torch.int64, device="npu") - arg2_1 = torch.randn(9000, x1_numel, dtype=torch.float32, device="npu") - buf44 = torch.empty((y0_numel, x1_numel), dtype=torch.float32, device="npu") - grid_size = 128 - result_path = "./result_profiling" - skip_first = 10 - wait = 0 - warmup = 3 - active = 30 - repeat = 1 - stream = torch.npu.current_stream() - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - l2_cache=False, - data_simplification=False, - ) - with torch_npu.profiler.profile( - activities=[ - torch_npu.profiler.ProfilerActivity.CPU, - torch_npu.profiler.ProfilerActivity.NPU, - ], - schedule=torch_npu.profiler.schedule( - wait=wait, - warmup=warmup, - active=active, - repeat=repeat, - skip_first=skip_first, - ), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), - record_shapes=True, - profile_memory=False, - with_stack=False, - with_flops=False, - with_modules=False, - experimental_config=experimental_config, - ) as prof: - stream.synchronize() - for _ in range(skip_first + (wait + warmup + 
active) * repeat): - triton_unk_fused_embedding_sum_5[[grid_size, 1, 1]](arg0_1, arg2_1, buf44, y0_numel, x1_numel, r2_numel, - 128, 32, num_warps=32, force_simt_only=True) - prof.step() - stream.synchronize() - triton_unk_fused_embedding_sum_5[[grid_size, 1, 1]](arg0_1, arg2_1, buf44, y0_numel, x1_numel, r2_numel, 128, 32, - num_warps=32, force_simt_only=True) - torch_out = torch_red_fused_embedding_sum_vec(arg0_1, arg2_1) - torch.testing.assert_close(buf44, torch_out, rtol=1e-04, atol=1e-04, equal_nan=True) diff --git a/third_party/ascend/examples/simt_perf_cases/test_index_select_flaggems.py b/third_party/ascend/examples/simt_perf_cases/test_index_select_flaggems.py deleted file mode 100644 index 3547288c7..000000000 --- a/third_party/ascend/examples/simt_perf_cases/test_index_select_flaggems.py +++ /dev/null @@ -1,160 +0,0 @@ -import logging -import triton -import triton.language as tl -import torch -import torch_npu -import pytest - - -@triton.jit -def index_select_kernel_dim_0(inp, out, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): - """flaggems originl index_select implementation on dim 0""" - pid_x = tl.program_id(axis=0) - pid_y = tl.program_id(axis=1) - cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N) - cols_mask = cols_offsets < N - rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M) - indices = tl.load(index + rows_offsets, mask=(rows_offsets < index_len), other=0) - inp_off = indices[:, None] * N + cols_offsets[None, :] - out_off = rows_offsets[:, None] * N + cols_offsets[None, :] - selected = tl.load(inp + inp_off, mask=cols_mask[None, :], other=0.0) - tl.store(out + out_off, selected, mask=cols_mask[None, :]) - - -@triton.jit -def index_select_kernel_dim_1(inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): - """flaggems originl index_select implementation on dim 1""" - pid_x = tl.program_id(axis=0) - pid_y = tl.program_id(axis=1) - rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] - rows_mask = rows_offsets < M - cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N) - out_mask = rows_mask and (cols_offsets < index_len) - indices = tl.load(index + cols_offsets, mask=(cols_offsets < index_len), other=0) - inp_off = rows_offsets * N + indices[None, :] - out_off = rows_offsets * index_len + cols_offsets[None, :] - selected = tl.load(inp + inp_off, mask=rows_mask, other=0.0) - tl.store(out + out_off, selected, mask=out_mask) - - -def get_grid(dim, index_len, BLOCK_M, BLOCK_N, M, N): - if dim == 0: - grid_M_size = triton.cdiv(index_len, BLOCK_M) - grid_N_size = triton.cdiv(N, BLOCK_N) - else: - grid_M_size = triton.cdiv(M, BLOCK_M) - grid_N_size = triton.cdiv(index_len, BLOCK_N) - logging.info("grid_M_size:%d, grid_N_size:%d", grid_M_size, grid_N_size) - return grid_M_size, grid_N_size - - -def kernel_select(dim, grid_M_size, grid_N_size, inp, out, M, N, index, index_len, BLOCK_M, BLOCK_N): - """kernel select""" - if dim == 0: - index_select_kernel_dim_0[[ - grid_M_size, - grid_N_size, - ]]( - inp, - out, - N, - index, - index_len, - BLOCK_M, - BLOCK_N, - num_warps=32, - force_simt_only=True, - ) - else: - index_select_kernel_dim_1[[ - grid_M_size, - grid_N_size, - ]]( - inp, - out, - M, - N, - index, - index_len, - BLOCK_M, - BLOCK_N, - num_warps=32, - force_simt_only=True, - ) - - -@pytest.mark.parametrize( - "param_list", - [ - [[26, 140], 0, 23, 32, 32], - [[3, 16], 0, 3, 32, 32], - [[9, 6], 0, 9, 32, 32], - [[992, 16], 0, 632, 32, 32], - [[500000, 37], 0, 322364, 32, 32], - [[500000, 240], 0, 375144, 
64, 64], - [[500000, 37], 0, 324344, 32, 32], - [[500000, 240], 0, 377816, 64, 64], - [[64, 64], 1, 6, 32, 32], - ], -) -def test_index_select_flaggems(param_list): - shape, dim, index_size, BLOCK_M, BLOCK_N = param_list - inp = torch.randn(shape, dtype=torch.float32, device="npu") - index = torch.randint(0, inp.size(dim), [index_size], device="npu") - golden = torch.index_select(inp, dim, index) - index_len = index.numel() - if index.ndim == 0: - index = index.unsqueeze(0) - dim = dim % inp.ndim - inp_shape = list(inp.shape) - N = inp_shape[1] - M = inp.numel() // N - if dim != 0 and dim != 1: - logging.error("error dim:%d", dim) - return - grid_M_size, grid_N_size = get_grid(dim, index_len, BLOCK_M, BLOCK_N, M, N) - out_shape = list(inp.shape) - out_shape[dim] = index_len - out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device) - # Perf testing - result_path = "./result_profiling" - skip_first = 10 - wait = 0 - warmup = 3 - active = 30 - repeat = 1 - stream = torch.npu.current_stream() - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - l2_cache=False, - data_simplification=False, - ) - with torch_npu.profiler.profile( - activities=[ - torch_npu.profiler.ProfilerActivity.CPU, - torch_npu.profiler.ProfilerActivity.NPU, - ], - schedule=torch_npu.profiler.schedule( - wait=wait, - warmup=warmup, - active=active, - repeat=repeat, - skip_first=skip_first, - ), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), - record_shapes=True, - profile_memory=False, - with_stack=False, - with_flops=False, - with_modules=False, - experimental_config=experimental_config, - ) as prof: - stream.synchronize() - for _ in range(skip_first + (wait + warmup + active) * repeat): - kernel_select(dim, grid_M_size, grid_N_size, inp, out, M, N, index, index_len, BLOCK_M, BLOCK_N) - prof.step() - stream.synchronize() - # Correctness testing - kernel_select(dim, grid_M_size, grid_N_size, inp, out, M, N, index, index_len, BLOCK_M, BLOCK_N) - torch.testing.assert_close(golden, out, rtol=1e-04, atol=1e-04, equal_nan=True) diff --git a/third_party/ascend/examples/simt_perf_cases/test_index_select_flaggems_ascend.py b/third_party/ascend/examples/simt_perf_cases/test_index_select_flaggems_ascend.py deleted file mode 100644 index 732f9b3e2..000000000 --- a/third_party/ascend/examples/simt_perf_cases/test_index_select_flaggems_ascend.py +++ /dev/null @@ -1,134 +0,0 @@ -import logging -import triton -import triton.language as tl -import torch -import torch_npu -import pytest - - -@triton.jit -def index_select_kernel_ascend_dim_0(inp, out, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): - """flaggems ascend index_select implementation on dim 0""" - pid_x = tl.program_id(axis=0) - pid_y = tl.program_id(axis=1) - grid_x = tl.num_programs(0) - grid_y = tl.num_programs(1) - for x in range(pid_x * BLOCK_M, index_len, grid_x * BLOCK_M): - rows_offsets = x + tl.arange(0, BLOCK_M) - indices = tl.load(index + rows_offsets, mask=(rows_offsets < index_len), other=0) - for y in range(pid_y * BLOCK_N, N, grid_y * BLOCK_N): - cols_offsets = y + tl.arange(0, BLOCK_N) - cols_mask = cols_offsets < N - inp_off = indices[:, None] * N + cols_offsets[None, :] - out_off = rows_offsets[:, None] * N + cols_offsets[None, :] - selected = tl.load(inp + inp_off, mask=cols_mask[None, :], other=0.0) - tl.store(out + out_off, selected, mask=cols_mask[None, :]) - - -def 
get_grid(): - import triton.runtime.driver as driver - num_cores = driver.active.utils.get_aivector_core_num() - logging.info("grid_M_size:%d, grid_N_size:%d", num_cores, 1) - return num_cores, 1 - - -@pytest.mark.parametrize( - "param_list", - [ - [[26, 140], 0, 23, 32, 32], - [[3, 16], 0, 3, 32, 32], - [[9, 6], 0, 9, 32, 32], - [[992, 16], 0, 632, 32, 32], - [[500000, 37], 0, 322364, 32, 32], - [[500000, 240], 0, 375144, 64, 64], - [[500000, 37], 0, 324344, 32, 32], - [[500000, 240], 0, 377816, 64, 64], - ], -) -def test_index_select_flaggems_ascend(param_list): - shape, dim, index_size, BLOCK_M, BLOCK_N = param_list - inp = torch.randn(shape, dtype=torch.float32, device="npu") - index = torch.randint(0, inp.size(dim), [index_size], device="npu") - golden = torch.index_select(inp, dim, index) - index_len = index.numel() - if index.ndim == 0: - index = index.unsqueeze(0) - dim = dim % inp.ndim - inp_shape = list(inp.shape) - N = inp_shape[1] - M = inp.numel() // N - if dim != 0: - logging.error("error dim:%d", dim) - return - grid_M_size, grid_N_size = get_grid() - out_shape = list(inp.shape) - out_shape[dim] = index_len - out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device) - # Perf testing - result_path = "./result_profiling" - skip_first = 10 - wait = 0 - warmup = 3 - active = 30 - repeat = 1 - stream = torch.npu.current_stream() - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - l2_cache=False, - data_simplification=False, - ) - with torch_npu.profiler.profile( - activities=[ - torch_npu.profiler.ProfilerActivity.CPU, - torch_npu.profiler.ProfilerActivity.NPU, - ], - schedule=torch_npu.profiler.schedule( - wait=wait, - warmup=warmup, - active=active, - repeat=repeat, - skip_first=skip_first, - ), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), - record_shapes=True, - profile_memory=False, - with_stack=False, - with_flops=False, - with_modules=False, - experimental_config=experimental_config, - ) as prof: - stream.synchronize() - for _ in range(skip_first + (wait + warmup + active) * repeat): - index_select_kernel_ascend_dim_0[[ - grid_M_size, - grid_N_size, - ]]( - inp, - out, - N, - index, - index_len, - BLOCK_M, - BLOCK_N, - num_warps=32, - force_simt_only=True, - ) - prof.step() - stream.synchronize() - # Correctness testing - index_select_kernel_ascend_dim_0[[ - grid_M_size, - grid_N_size, - ]]( - inp, - out, - N, - index, - index_len, - BLOCK_M, - BLOCK_N, - num_warps=32, - force_simt_only=True, - ) - torch.testing.assert_close(golden, out, rtol=1e-04, atol=1e-04, equal_nan=True) diff --git a/third_party/ascend/examples/simt_perf_cases/test_reduce.py b/third_party/ascend/examples/simt_perf_cases/test_reduce.py deleted file mode 100644 index 5727935cd..000000000 --- a/third_party/ascend/examples/simt_perf_cases/test_reduce.py +++ /dev/null @@ -1,88 +0,0 @@ -import triton -import triton.language as tl -import torch -import torch_npu -import pytest - - -@triton.jit -def triton_unk_reduce(in_ptr0, out_ptr0, y0_numel, x1_numel, X1BLOCK_SUB: tl.constexpr): - y0_offset = tl.program_id(0) - grid_size = tl.num_programs(0) - base_x1 = tl.arange(0, X1BLOCK_SUB) - loops_x1 = (x1_numel + X1BLOCK_SUB - 1) // X1BLOCK_SUB - for y0 in range(y0_offset, y0_numel, grid_size): - _tmp8 = tl.full([X1BLOCK_SUB], 0, tl.float32) - for loop_x1 in range(loops_x1): - x1 = (loop_x1 * X1BLOCK_SUB) + base_x1 - x1_mask = x1 < x1_numel - tmp0 
= tl.load(in_ptr0 + x1_numel * y0 + x1, x1_mask, other=0.0) - _tmp8 += tmp0 - tmp8 = tl.sum(_tmp8, 0) - tl.store(out_ptr0 + y0, tmp8) - - -def torch_reduce(arg0): - return arg0.sum(dim=1) - - -@pytest.mark.parametrize( - "param_list", - [ - [128, 40000], - ], -) -def test_reduce(param_list): - y0_numel, x1_numel = param_list - arg0_1 = torch.randn(y0_numel, x1_numel, dtype=torch.float32, device="npu") - buf44 = torch.empty((y0_numel), dtype=torch.float32, device="npu") - grid_size = 64 - result_path = "./result_profiling" - skip_first = 10 - wait = 0 - warmup = 3 - active = 30 - repeat = 1 - stream = torch.npu.current_stream() - experimental_config = torch_npu.profiler._ExperimentalConfig( - aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, - profiler_level=torch_npu.profiler.ProfilerLevel.Level1, - l2_cache=False, - data_simplification=False, - ) - with torch_npu.profiler.profile( - activities=[ - torch_npu.profiler.ProfilerActivity.CPU, - torch_npu.profiler.ProfilerActivity.NPU, - ], - schedule=torch_npu.profiler.schedule( - wait=wait, - warmup=warmup, - active=active, - repeat=repeat, - skip_first=skip_first, - ), - on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), - record_shapes=True, - profile_memory=False, - with_stack=False, - with_flops=False, - with_modules=False, - experimental_config=experimental_config, - ) as prof: - stream.synchronize() - for _ in range(skip_first + (wait + warmup + active) * repeat): - triton_unk_reduce[[grid_size, 1, 1]]( - arg0_1, - buf44, - y0_numel, - x1_numel, - 4096, - num_warps=32, - force_simt_only=True, - ) - prof.step() - stream.synchronize() - triton_unk_reduce[[grid_size, 1, 1]](arg0_1, buf44, y0_numel, x1_numel, 4096, num_warps=32, force_simt_only=True) - torch_out = torch_reduce(arg0_1) - torch.testing.assert_close(buf44, torch_out, rtol=1e-04, atol=1e-04, equal_nan=True) diff --git a/third_party/ascend/examples/tutorials/05-matrix-multiplication-flagtree.py b/third_party/ascend/examples/tutorials/05-matrix-multiplication-flagtree.py deleted file mode 100644 index 83894e807..000000000 --- a/third_party/ascend/examples/tutorials/05-matrix-multiplication-flagtree.py +++ /dev/null @@ -1,177 +0,0 @@ -""" -Matrix Multiplication (Flagtree Hints Version) -=============== -""" - -import triton -import triton.language as tl -import torch -import torch_npu - -DEV = "npu" -activation = "leaky_relu_custom" - - -def get_autotune_config(): - return [ - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}), - ] - - -@triton.autotune( - configs=get_autotune_config(), - key=["M", "N", "K"], -) -@triton.jit -def matmul_kernel( - # Pointers to matrices - a_ptr, b_ptr, c_ptr, - # Matrix dimensions - M, N, K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). - stride_am, stride_ak, # - stride_bk, stride_bn, # - stride_cm, stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # - ACTIVATION: tl.constexpr, # -): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - GROUP_SIZE_M: tl.constexpr = 1 - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. 
- # This is done in a grouped ordering to promote L2 data reuse. - # See above `L2 Cache Optimizations` section for details. - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - # See above `Pointer Arithmetic` section for details - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs_base = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs_base = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - msk_m = offs_am < M - msk_n = offs_bn < N - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - # Load the next block of A and B, generate a mask by checking the K dimension. - # If it is out of bounds, set it to 0. - a_ptrs = a_ptrs_base + k * BLOCK_SIZE_K * stride_ak - b_ptrs = b_ptrs_base + k * BLOCK_SIZE_K * stride_bk - a = tl.load( - a_ptrs, - mask=msk_m[:, None] and (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, - ) - b = tl.load( - b_ptrs, - mask=msk_n[None, :] and (offs_k[:, None] < K - k * BLOCK_SIZE_K), - other=0.0, - ) - # We accumulate along the K dimension. - accumulator = tl.dot(a, b, accumulator) - # You can fuse arbitrary activation functions here - # while the accumulator is still in FP32! - # Original vector operations - # if ACTIVATION == "leaky_relu_custom": - # accumulator = leaky_relu_custom(accumulator) - # c = accumulator.to(tl.float16) - # # ----------------------------------------------------------- - # # Write back the block of the output matrix C with masks. - # offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - # offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - # c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - # tl.store(c_ptrs, c, mask=c_mask) - # Comment out the following lines to enable split the workload to two vector cores using flagtree hints - SUB_BLK_M: tl.constexpr = BLOCK_SIZE_M // 2 - for s in range(0, 2): # @hint: bind_sub_block - vec_sub_blk = tl.extract_slice(accumulator, (s * SUB_BLK_M, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1)) - if ACTIVATION == "leaky_relu_custom": - vec_sub_blk = leaky_relu_custom(vec_sub_blk) - c_sub_blk = vec_sub_blk.to(tl.float16) - # ----------------------------------------------------------- - # Write back the block of the output matrix C with masks. 
- offs_cm = pid_m * BLOCK_SIZE_M + s * SUB_BLK_M + tl.arange(0, SUB_BLK_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c_sub_blk, mask=c_mask) - - -# We can fuse `leaky_relu_custom` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. -@triton.jit -def leaky_relu_custom(x): - return tl.where(x >= 0, x, 0.01 * x) + 1.0 - - -def torch_matmul(a, b, activation=""): - c = torch.matmul(a, b) - if activation == "leaky_relu_custom": - c = torch.where(c >= 0, c, 0.01 * c) + 1.0 - return c - - -# %% -# We can now create a convenience wrapper function that only takes two input tensors, -# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. - - -def matmul(a, b, activation=""): - # Check constraints. - assert a.shape[1] == b.shape[0], "Incompatible dimensions" - assert a.is_contiguous(), "Matrix A must be contiguous" - M, K = a.shape - K, N = b.shape - # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=torch.float16) - # 1D launch kernel where each block gets its own program. - grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - matmul_kernel[grid]( - a, b, c, # - M, N, K, # - a.stride(0), a.stride(1), # - b.stride(0), b.stride(1), # - c.stride(0), c.stride(1), # - ACTIVATION=activation, # - ) - return c - - -# %% -# Unit Test -# --------- -# -# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS). -torch.npu.set_device(1) -torch.manual_seed(0) -a = torch.randn((512, 512), device=DEV, dtype=torch.float16) -b = torch.randn((512, 512), device=DEV, dtype=torch.float16) -triton_output = matmul(a, b, activation) -torch_output = torch_matmul(a, b, activation) -print(f"triton_output_with_fp16_inputs={triton_output}") -print(f"torch_output_with_fp16_inputs={torch_output}") diff --git a/third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized-flagtree.py b/third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized-flagtree.py deleted file mode 100644 index a8d6a9794..000000000 --- a/third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized-flagtree.py +++ /dev/null @@ -1,192 +0,0 @@ -import torch -import torch_npu -import triton -import triton.language as tl -import triton.runtime.driver as driver - - -# get device properties of npu -def get_npu_properties(): - device = torch.npu.current_device() - return driver.active.utils.get_device_properties(device) - - -@triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 4}), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 5}), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 4}), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 5}), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 6}), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 7}), - triton.Config({"BLOCK_M": 256, 
"BLOCK_N": 128, "BLOCK_K": 256, "BLOCK_TRESHHOLD": 8}), - ], key=["M", "N", "K"]) -@triton.jit -def matmul_kernel( - mat_a, - mat_b, - mat_c, - M: tl.constexpr, - N: tl.constexpr, - K: tl.constexpr, - num_cores: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - BLOCK_TRESHHOLD: tl.constexpr, -): - pid = tl.program_id(axis=0) - task_m_idx = 0 - task_n_idx = 0 - ''' - 水平分核方式每个任务块编号如下 - [0, 1, 2, 3, 4, 5, 6, 7] - [8, 9, 10, 11, 12, 13, 14, 15] - [16, 17, 18, 19, 20, 21, 22, 23] - [24, 25, 26, 27, 28, 29, 30, 31] - [32, 33, 34, 35, 36, 37, 38, 39] - [40, 41, 42, 43, 44, 45, 46, 47] - [48, 49, 50, 51, 52, 53, 54, 55] - [56, 57, 58, 59, 60, 61, 62, 63] - 0核处理 0 20 40 60 4块任务 - 1核处理 1 21 41 61 4块任务 - 2核处理 2 22 42 62 4块任务 - ... - 19核处理 19 39 59 3块任务 - - 大shape下如果使用传统水平分核方式,会有如下问题 - 1:同一时间大量核心需要访问同一块左矩阵内存,产生Bank冲突,导致硬件访问效率降低 - 2:当完成一整行mat_c运算时,已经将所有右矩阵数据全部使用上,右矩阵较大时会超过L2Cache的容量上限, - 从而导致L2Cache的搬入及换出,此后每行运算都会或多或少产生CacheMiss,导致L2Cche命中率较低,影响 - 算子执行效率 - 此处使用8 * 8对角线分核方式可以按8 * 8的方块沿对角线方向分核计算,可以很大程度优化上面两点。 - - 此处以8*8对角线分核为例,实际以BLOCK_TRESHHOLD为tune参数选择最优的阈值 - 8 * 8 对角线分核方式中,每8 * 8分格内任务块编号如下 - [0, 8, 16, 24, 32, 40, 48, 56] - [57, 1, 9, 17, 25, 33, 41, 49] - [50, 58, 2, 10, 18, 26, 34, 42] - [43, 51, 59, 3, 11, 19, 27, 35] - [36, 44, 52, 60, 4, 12, 20, 28] - [29, 37, 45, 53, 61, 5, 13, 21] - [22, 30, 38, 46, 54, 62, 6, 14] - [15, 23, 31, 39, 47, 55, 63, 7] - - M轴方向超过8个基本块时,使用对角线分核可以明显减小Bank冲突 - 当右矩阵大小超过L2Cache大小时,采取对角线分核可以提升L2Cache利用率 - 所以当矩阵在M和N方向均超过8块时使能对角线分核即可有优化,当右矩阵大小超过L2Cache大小时优化效果尤为明显 - ''' - NUM_BLOCKS_M = triton.cdiv(M, BLOCK_M) - NUM_BLOCKS_N = triton.cdiv(N, BLOCK_N) - NUM_BLOCKS = NUM_BLOCKS_M * NUM_BLOCKS_N - #当任务量较多时,可以使能对角线分核策略进行优化 - if NUM_BLOCKS_M >= BLOCK_TRESHHOLD and NUM_BLOCKS_N >= BLOCK_TRESHHOLD: - for block_idx in range(pid, NUM_BLOCKS, num_cores): - #8 * 8 对角线分核代码实现 - curThresholdM = BLOCK_TRESHHOLD if block_idx < ( - NUM_BLOCKS_M // BLOCK_TRESHHOLD * BLOCK_TRESHHOLD) * NUM_BLOCKS_N else NUM_BLOCKS_M % BLOCK_TRESHHOLD - curThresholdM_thresholdN = curThresholdM * BLOCK_TRESHHOLD - curThresholdN = BLOCK_TRESHHOLD if block_idx % (NUM_BLOCKS_N * BLOCK_TRESHHOLD) < ( - curThresholdM * - NUM_BLOCKS_N) // curThresholdM_thresholdN * curThresholdM_thresholdN else NUM_BLOCKS_N % BLOCK_TRESHHOLD - localRelativeBlock = block_idx % (BLOCK_TRESHHOLD * NUM_BLOCKS_N) % (BLOCK_TRESHHOLD * curThresholdM) - task_m_idx = localRelativeBlock % curThresholdM + block_idx // (BLOCK_TRESHHOLD * - NUM_BLOCKS_N) * BLOCK_TRESHHOLD - #求最小公倍数,方便求基本块的坐标 - x, y = curThresholdM, curThresholdN if curThresholdM > curThresholdN else curThresholdN, curThresholdM - while y != 0: - x, y = y, x % y - lcm = curThresholdM * curThresholdN // x - task_n_idx = (localRelativeBlock + (localRelativeBlock // lcm)) % curThresholdN + block_idx % ( - BLOCK_TRESHHOLD * NUM_BLOCKS_N) // curThresholdM_thresholdN * BLOCK_TRESHHOLD - - m_start = task_m_idx * BLOCK_M - n_start = task_n_idx * BLOCK_N - - mat_c_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k_start in range(0, K, BLOCK_K): - mat_a_offset = ( - (m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + (k_start + tl.arange(0, BLOCK_K))[None, :] - mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( - (k_start + tl.arange(0, BLOCK_K)) < K)[None, :] - mat_a_block = tl.load(mat_a + mat_a_offset, mask=mat_a_mask, other=0.0) # @hint: dot_pad_only_k - mat_b_offset = ( - (k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + (n_start + tl.arange(0, BLOCK_N))[None, :] - mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, 
None] & ( - (n_start + tl.arange(0, BLOCK_N)) < N)[None, :] - mat_b_block = tl.load(mat_b + mat_b_offset, mask=mat_b_mask, other=0.0) # @hint: dot_pad_only_k - mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block) - mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + (n_start + tl.arange(0, BLOCK_N))[None, :] - mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( - (n_start + tl.arange(0, BLOCK_N)) < N)[None, :] - tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask=mat_c_mask) - else: - #传统顺序分核 - for block_idx in range(pid, NUM_BLOCKS, num_cores): - task_m_idx = block_idx // NUM_BLOCKS_N - task_n_idx = block_idx % NUM_BLOCKS_N - m_start = task_m_idx * BLOCK_M - n_start = task_n_idx * BLOCK_N - - mat_c_block = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k_start in range(0, K, BLOCK_K): - mat_a_offset = ( - (m_start + tl.arange(0, BLOCK_M)) * K)[:, None] + (k_start + tl.arange(0, BLOCK_K))[None, :] - mat_a_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( - (k_start + tl.arange(0, BLOCK_K)) < K)[None, :] - mat_a_block = tl.load(mat_a + mat_a_offset, mask=mat_a_mask, other=0.0) # @hint: dot_pad_only_k - mat_b_offset = ( - (k_start + tl.arange(0, BLOCK_K)) * N)[:, None] + (n_start + tl.arange(0, BLOCK_N))[None, :] - mat_b_mask = ((k_start + tl.arange(0, BLOCK_K)) < K)[:, None] & ( - (n_start + tl.arange(0, BLOCK_N)) < N)[None, :] - mat_b_block = tl.load(mat_b + mat_b_offset, mask=mat_b_mask, other=0.0) # @hint: dot_pad_only_k - mat_c_block = tl.dot(mat_a_block, mat_b_block, mat_c_block) - mat_c_offset = ((m_start + tl.arange(0, BLOCK_M)) * N)[:, None] + (n_start + tl.arange(0, BLOCK_N))[None, :] - mat_c_mask = ((m_start + tl.arange(0, BLOCK_M)) < M)[:, None] & ( - (n_start + tl.arange(0, BLOCK_N)) < N)[None, :] - tl.store(mat_c + mat_c_offset, mat_c_block.to(tl.bfloat16), mask=mat_c_mask) - - -def triton_matmul( - mat_a, - mat_b, -): - m = mat_a.shape[0] - k = mat_a.shape[1] - n = mat_b.shape[1] - mat_c = torch.empty(m, n, dtype=mat_a.dtype, device=mat_a.device) - ''' - NPU芯片更加亲和512B对齐场景,如下分块通用性能较好,可以使用autotune选取最优 - BLOCK_M = 128 - BLOCK_N = 256 - BLOCK_K = 256 - ''' - - num_cores = get_npu_properties()["num_aicore"] - - matmul_kernel[(num_cores, )](mat_a, mat_b, mat_c, m, n, k, num_cores) - return mat_c - - -if __name__ == "__main__": - M = 2048 - K = 7168 - N = 16384 - - mat_a = torch.randn([M, K], dtype=torch.bfloat16, device="npu") - mat_b = torch.randn([K, N], dtype=torch.bfloat16, device="npu") - - result = triton_matmul(mat_a, mat_b) - golden = torch.matmul(mat_a, mat_b) - - mask = golden.abs() < 1.0 - tmpatol = tmprtol = 2**-6 - try: - torch.testing.assert_close(result[mask], golden[mask], atol=tmpatol, rtol=0) - torch.testing.assert_close(result[~mask], golden[~mask], atol=0, rtol=tmprtol) - print("run matmul success") - except: - print(f"[ERROR] M={M} ,K={K}, N={N}存在精度问题") diff --git a/third_party/ascend/language/ascend/__init__.py b/third_party/ascend/language/ascend/__init__.py deleted file mode 100644 index 229b57d87..000000000 --- a/third_party/ascend/language/ascend/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . import libdevice - -__all__ = ["libdevice"] diff --git a/third_party/ascend/language/cann/__init__.py b/third_party/ascend/language/cann/__init__.py new file mode 100644 index 000000000..b1dd3a1a0 --- /dev/null +++ b/third_party/ascend/language/cann/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from . import libdevice +from . import extension + +extension.parallel = extension.aux_ops.parallel +libdevice.atan2 = extension.math_ops.atan2 +libdevice.isfinited = extension.math_ops.isfinited +libdevice.finitef = extension.math_ops.finitef +libdevice.flip = extension.flip + +from triton.language import math + +libdevice.umulhi = math.umulhi +libdevice.exp = math.exp +libdevice.exp2 = math.exp2 +libdevice.log = math.log +libdevice.log2 = math.log2 +libdevice.cos = math.cos +libdevice.sin = math.sin +libdevice.sqrt = math.sqrt +libdevice.sqrt_rn = math.sqrt_rn +libdevice.rsqrt = math.rsqrt +libdevice.div_rn = math.div_rn +libdevice.erf = math.erf +libdevice.floor = math.floor +libdevice.ceil = math.ceil +libdevice.fdiv = math.fdiv +libdevice.fma = math.fma +libdevice.abs = math.abs + +__all__ = ["libdevice", "extension"] diff --git a/third_party/ascend/language/cann/extension/__init__.py b/third_party/ascend/language/cann/extension/__init__.py new file mode 100644 index 000000000..20c339bc8 --- /dev/null +++ b/third_party/ascend/language/cann/extension/__init__.py @@ -0,0 +1,122 @@ +try: + import acl + is_compile_on_910_95 = acl.get_soc_name().startswith("Ascend910_95") +except Exception as e: + is_compile_on_910_95 = False + +from .core import ( + ascend_address_space, + builtin, + CORE, + copy_from_ub_to_l1, + debug_barrier, + fixpipe, + FixpipeDMAMode, + FixpipeDualDstMode, + FixpipePreQuantMode, + FixpipePreReluMode, + int64, + is_builtin, + MODE, + PIPE, + sub_vec_id, + sub_vec_num, + sync_block_all, + sync_block_set, + sync_block_wait, + SYNC_IN_VF, +) + +from .scope import scope + +from .custom_op import ( + custom, + custom_semantic, + register_custom_op, +) + +from . 
import builtin_custom_ops + +from .math_ops import (atan2, isfinited, finitef) + +from .aux_ops import ( + parallel, + compile_hint, + multibuffer, +) + +from .vec_ops import ( + insert_slice, + extract_slice, + get_element, + sort, + flip, + cast, +) + +from .mem_ops import ( + index_select, + index_put, + gather_out_to_ub, + scatter_ub_to_out, + index_select_simd, +) + +__all__ = [ + # core + "builtin", + "copy_from_ub_to_l1", + "CORE", + "debug_barrier", + "fixpipe", + "FixpipeDMAMode", + "FixpipeDualDstMode", + "FixpipePreQuantMode", + "FixpipePreReluMode", + "int64", + "is_builtin", + "MODE", + "PIPE", + "sub_vec_id", + "sub_vec_num", + "sync_block_all", + "SYNC_IN_VF", + + # address space + "ascend_address_space", + + # scope + "scope", + + # custom op + "custom", + "custom_semantic", + "register_custom_op", + + # math ops + "atan2", + "isfinited", + "finitef", + + # aux ops + "sync_block_set", + "sync_block_wait", + "parallel", + "compile_hint", + "multibuffer", + + # vec ops + "insert_slice", + "extract_slice", + "get_element", + "sort", + "flip", + "cast", + + # mem ops + "index_select", + "index_put", + "gather_out_to_ub", + "scatter_ub_to_out", + "index_select_simd", +] diff --git a/third_party/ascend/language/cann/extension/_utils.py b/third_party/ascend/language/cann/extension/_utils.py new file mode 100644 index 000000000..6fbf826a1 --- /dev/null +++ b/third_party/ascend/language/cann/extension/_utils.py @@ -0,0 +1,54 @@ +import triton.language.core as tl +from triton._C.libtriton import ir + + +def custom_op(builder: ir.builder, op_name: str, **kwargs): + if op_name == "sync_block_all": + return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["mode"], kwargs["event_id"]) + + elif op_name == "sync_block_set": + return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"]) + + elif op_name == "sync_block_wait": + return builder.create_custom_op_for_inter_core_sync(op_name, kwargs["sender"], kwargs["event_id"]) + + raise ValueError(f"Unsupported custom op: {op_name}") + + +def _is_int_like_elem(x) -> bool: + """Accept int / tl.constexpr(int) / tl.tensor(int*).""" + if isinstance(x, int): + return True + if isinstance(x, tl.constexpr): + # constexpr value should be python int + return isinstance(x.value, int) + if isinstance(x, tl.tensor): + # Offsets/strides must be integer typed (i32/i64 etc.) 
+ return x.dtype.is_int() + return False + + +def _assert_int_like_tuple(name: str, xs): + assert isinstance(xs, (tuple, list)), f"{name} should be a tuple/list, but got {type(xs)}" + assert all(_is_int_like_elem(x) for x in xs), f"{name} should be integer" + + +def _convert_elem_to_ir_value(builder, elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) + if isinstance(elem, tl.constexpr): + if require_i64: + assert -2**63 <= elem.value < 2**63, f"Block pointers only support 64 bit `shape/strides`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int64(elem.value) + else: + assert -2**31 <= elem.value < 2**31, f"Block pointers only support 32 bit `offsets/block_shape`, " \ + f"got a value {elem.value} which is out of the range" + return builder.get_int32(elem.value) + elif isinstance(elem, tl.tensor): + if require_i64: + return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed()) + else: + return builder.create_int_cast(elem.handle, builder.get_int32_ty(), elem.dtype.is_int_signed()) + else: + assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" diff --git a/third_party/ascend/language/cann/extension/aux_ops.py b/third_party/ascend/language/cann/extension/aux_ops.py new file mode 100644 index 000000000..d872e44ac --- /dev/null +++ b/third_party/ascend/language/cann/extension/aux_ops.py @@ -0,0 +1,159 @@ +import triton.language as tl +from triton.language import semantic, core, standard +from triton.language.core import (_constexpr_to_value, _tensor_member_fn, _unwrap_iterable, builtin, constexpr, dtype, + tensor, check_bit_width, _unwrap_if_constexpr, range) +from triton.language.semantic import ( + wrap_tensor, + _str_to_rounding_mode, + not_equal, + _str_to_dot_input_precision, + binary_op_type_checking_impl, + integer_promote_impl, + broadcast_impl_shape, + _str_to_sem, + _str_to_scope, + bitcast, + bitwise_op_type_checking_impl, + to_tensor, + _str_to_load_cache_modifier, + _str_to_eviction_policy, + _str_to_padding_option, + _canonicalize_boundary_check, +) + +from typing import Optional, Tuple, List, overload, Union +from triton._C.libtriton import ir +from ._utils import custom_op + + +@_tensor_member_fn +@builtin +def sync_block_all(mode, event_id, _builder=None): + import warnings + + warnings.warn( + ("This method would be deprecated. Use al.sync_block_all instead."), + DeprecationWarning, + stacklevel=1, + ) + mode = _constexpr_to_value(mode) + event_id = _constexpr_to_value(event_id) + assert isinstance(mode, str), f"mode: {mode} is not string" + assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + assert mode == "all_cube" or mode == "all_vector" or mode == "all", f"ERROR: mode = {mode}, only supports all_cube/all_vector/all" + custom_op(_builder, "sync_block_all", mode=mode, event_id=event_id) + + +@_tensor_member_fn +@builtin +def sync_block_set(sender, receiver, event_id, _builder=None): + import warnings + + warnings.warn( + ("This method would be deprecated. 
Use al.sync_block_set instead."),
+        DeprecationWarning,
+        stacklevel=1,
+    )
+    sender = _constexpr_to_value(sender)
+    receiver = _constexpr_to_value(receiver)
+    event_id = _constexpr_to_value(event_id)
+    assert isinstance(sender, str) and (sender == "cube"
+                                        or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector"
+    assert isinstance(receiver, str) and (receiver == "cube" or receiver
+                                          == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector"
+    assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15"
+    if sender == receiver:
+        raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube')
+    custom_op(_builder, "sync_block_set", sender=sender, event_id=event_id)
+
+
+@_tensor_member_fn
+@builtin
+def sync_block_wait(sender, receiver, event_id, _builder=None):
+    import warnings
+
+    warnings.warn(
+        ("This method would be deprecated. Use al.sync_block_wait instead."),
+        DeprecationWarning,
+        stacklevel=1,
+    )
+    sender = _constexpr_to_value(sender)
+    receiver = _constexpr_to_value(receiver)
+    event_id = _constexpr_to_value(event_id)
+    assert isinstance(sender, str) and (sender == "cube"
+                                        or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector"
+    assert isinstance(receiver, str) and (receiver == "cube" or receiver
+                                          == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector"
+    assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15"
+    if sender == receiver:
+        raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube')
+    custom_op(_builder, "sync_block_wait", sender=sender, event_id=event_id)
+
+
+class parallel(range):
+    """
+    Iterator that counts upward forever, with parallel execution semantics.
+
+    This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of
+    :code:`triton.jit` functions. In addition, it allows the user to pass extra attributes to the compiler.
+    :param bind_sub_block: Tells the compiler whether multiple vector cores participate in the loop.
+     This is used in mixed cube-vector kernels on 910B. The number of vector cores is determined by the number of
+     iterations in this loop. Currently, at most 2 vector cores can be used on 910B.
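+
+    Illustrative sketch (assuming this extension module is imported as ``al``; ``SUB_M``,
+    ``BLOCK_N``, ``acc`` and ``process`` are placeholders supplied by the caller):
+
+        for s in al.parallel(0, 2, bind_sub_block=True):
+            # each iteration may be dispatched to a different vector core
+            sub_blk = al.extract_slice(acc, (s * SUB_M, 0), (SUB_M, BLOCK_N), (1, 1))
+            process(sub_blk)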
+ """ + + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None, + bind_sub_block: bool = False): + super().__init__(arg1, arg2, step, num_stages, loop_unroll_factor) + self.bind_sub_block = bind_sub_block + + +def compile_hint_impl(ptr: tensor, hint_name: str, hint_val, builder: ir.builder): + # simt mode does not support hint annotations + # FIXME: is_simt_mode + # if builder.is_simt_mode(): + # return + if not hint_val: + hint_val = builder.get_unit_attr() + elif isinstance(hint_val, bool): + hint_val = builder.get_bool_attr(hint_val) + elif isinstance(hint_val, int): + hint_val = builder.get_int32_attr(hint_val) + elif isinstance(hint_val, core.constexpr): + hint_val = builder.get_str_attr(hint_val.value) + elif isinstance(hint_val, list): + # only support i64 array attr for now + hint_val = builder.get_i64_array_attr(hint_val) + else: + raise ValueError(f"Unsupported hint value type: {type(hint_val)}") + builder.create_annotation(ptr.handle, hint_name, hint_val) + + +@builtin +def compile_hint(ptr, hint_name, hint_val=None, _builder=None): + # simt mode does not support hint annotations + if _builder.is_simt_mode(): + return + + def _unwrap(val): + return _unwrap_if_constexpr(val) if val else val + + hint_name = _constexpr_to_value(hint_name) + assert isinstance(hint_name, str), f"hint name: {hint_name} is not string" + if isinstance(hint_val, list): + hint_val = [_unwrap(val) for val in hint_val] + else: + hint_val = _unwrap(hint_val) + hint_val = _unwrap_if_constexpr(hint_val) if hint_val else hint_val + compile_hint_impl(ptr, hint_name, hint_val, _builder) + + +@builtin +def multibuffer(src: tensor, size, _builder=None): + """ + Set multi_buffer for an existing tensor + :src: tensor set to bufferize multiple time + :size: number of copies + """ + buffer_size = _constexpr_to_value(size) + assert isinstance(buffer_size, int) and buffer_size == 2, f"only support bufferize equals 2" + compile_hint_impl(src, "multi_buffer", buffer_size, _builder) diff --git a/third_party/ascend/language/cann/extension/builder.py b/third_party/ascend/language/cann/extension/builder.py new file mode 100644 index 000000000..cfd4be3b0 --- /dev/null +++ b/third_party/ascend/language/cann/extension/builder.py @@ -0,0 +1,84 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Ascend-specific builder utilities for code generation. 
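+
+Rough usage sketch (names refer to the helpers defined below):
+
+    setup_unified_builder(builder, ascend_builder)
+    # Ascend-only methods are now available on the main builder and are forwarded
+    # to the Ascend builder with the caller's insertion point and location kept in sync:
+    scope_op = builder.create_scope_op(attrs, ret_types)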
+""" + +__all__ = [ + "create_builder_method_wrapper", + "attach_builder_methods", + "setup_unified_builder", +] + + +def create_builder_method_wrapper(main_builder, delegate_builder, method_name): + """ + Create a wrapper that delegates a method call to another builder while + synchronizing insertion points and locations. + """ + delegate_method = getattr(delegate_builder, method_name) + + def wrapper(*args, **kwargs): + saved_ip = main_builder.get_insertion_point() + saved_loc = main_builder.get_loc() + delegate_builder.restore_insertion_point(saved_ip) + if saved_loc: + delegate_builder.set_loc(saved_loc) + result = delegate_method(*args, **kwargs) + main_builder.restore_insertion_point(saved_ip) + if saved_loc: + main_builder.set_loc(saved_loc) + return result + + wrapper.__name__ = method_name + wrapper.__doc__ = getattr(delegate_method, '__doc__', None) + return wrapper + + +def attach_builder_methods(main_builder, delegate_builder, method_names): + """Attach multiple methods from a delegate builder to the main builder.""" + for method_name in method_names: + wrapper = create_builder_method_wrapper(main_builder, delegate_builder, method_name) + setattr(main_builder, method_name, wrapper) + + +def setup_unified_builder(main_builder, ascend_builder): + """Set up a unified builder interface by attaching methods from specialized builders.""" + main_builder._ascend_builder = ascend_builder + ascend_methods = [ + 'create_scope_op', + 'scope_return', + 'get_t_core_type_attr_name', + 'get_t_core_type_cube_attr', + 'get_t_core_type_vector_attr', + 'get_target_attribute', + 'create_get_sub_vec_id', + 'create_copy_buffer', + 'create_copy_tensor', + 'create_fixpipe', + 'create_bind_buffer', + 'create_debug_barrier', + 'is_910_95', + "sync_block_set", + "sync_block_wait", + "create_convert_layout", + 'sync_block_all', + ] + attach_builder_methods(main_builder, ascend_builder, ascend_methods) diff --git a/third_party/ascend/language/cann/extension/builtin_custom_ops.py b/third_party/ascend/language/cann/extension/builtin_custom_ops.py new file mode 100644 index 000000000..49ac7bed5 --- /dev/null +++ b/third_party/ascend/language/cann/extension/builtin_custom_ops.py @@ -0,0 +1,220 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
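+
+# Illustrative note (a sketch, not part of the public API): each class below is
+# registered with @register_custom_op, names the builtin it maps to (`name`), and
+# declares the core / pipe / execution mode it targets; its __init__ only validates
+# arguments and records per-argument type overrides (e.g. in self.arg_type), while
+# the actual IR emission is handled elsewhere by the custom-op machinery (see
+# custom_op.py). A hypothetical new op would follow the same skeleton:
+#
+#     @register_custom_op
+#     class _my_op:
+#         name = '__builtin_my_op'
+#         core = CORE.VECTOR
+#         pipe = PIPE.PIPE_V
+#         mode = MODE.SIMT
+#
+#         def __init__(self, src, out=None):
+#             assert src.type.is_ptr() or src.dtype.is_ptr(), "src should be a pointer"
+#             assert out is not None, "out is required"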
+ +import triton.language.core as tl +from .custom_op import register_custom_op +from .core import CORE, PIPE, MODE +from ._utils import _is_int_like_elem, _assert_int_like_tuple + + +@register_custom_op +class _index_select: + """ + This operation gathers values from the src GM tensor into the out UB tensor + at positions with offsets specified by the index UB tensor along the specified + dimension using a SIMT template. This operation supports 2D–5D. + + Arguments: + - src: pointer type, the source tensor pointer (in GM) + - index: tensor, a tensor to gather (in UB) + - dim: int, the dimension to gather along + - bound: int, the upper boundary for index + - end_offset: tuple of int, the end offsets of each dimension for index tensor + - start_offset: tuple of int, the start offsets of each dimension for src tensor + - src_stride: tuple of int, the stride of each dimension of src tensor + - other(Optional): scalar value, the default value when index is out of boundary (in UB) + - out: the output tensor (in UB) + + Note: + - Supported source ranks: 2D ~ 5D. + - Supported index ranks: 1D or 2D. + - `dim` must be valid (0 <= dim < source ranks). + + Reference formula: + Index select operation for different tensor ranks: + 1. 2D index gather (0 <= dim <= 1) + 1.1 dim = 0, index_rank = 1, src_rank = 2, out_rank = 2 + index_shape = (Ai,) + end_offset = (Ai_end, B_end) + start_offset = (0, B_begin) + out[i][0:B_end-B_begin] = src[index[i]][B_begin:B_end] + 1.2 dim = 0, index_rank = 2, src_rank = 2, out_rank = 3 + index_shape = (Ai, Aj) + end_offset = (Ai_end, Aj_end, B_end) + start_offset = (0, B_begin) + out[i][j][0:B_end-B_begin] = src[index[i][j]][B_begin:B_end] + 2. 3D index gather (0 <= dim <= 2) + 2.1 dim = 0, index_rank = 2, src_rank = 3, out_rank = 4 + index_shape = (Ai, Aj) + end_offset = (Ai_end, Aj_end, B_end, C_end) + start_offset = (0, B_begin, C_begin) + out[i][j][0:B_end-B_begin][0:C_end-C_begin] = src[index[i][j]][B_begin:B_end][C_begin:C_end] + and so on. + """ + name = '__builtin_index_select' + core = CORE.VECTOR + pipe = PIPE.PIPE_V + mode = MODE.SIMT + + def __init__(self, src, index, dim, bound: tl.int64, end_offset, start_offset, src_stride, other=None, out=None): + assert src.type.is_ptr() or src.dtype.is_ptr(), f"src should be a pointer, but got {src.type}" + assert index.dtype.is_int(), "index should be integer tensor" + src_rank = len(src_stride) + idx_rank = len(index.shape) + assert 2 <= src_rank <= 5, f"src rank should in [2, 5], but got {src_rank}" + assert 1 <= idx_rank <= 2, f"index rank should in [1, 2], but got {idx_rank}" + assert _is_int_like_elem(dim), "dim should be an integer" + assert _is_int_like_elem(bound), "bound should be an integer" + assert 0 <= dim < src_rank, f"dim should in [0, {src_rank - 1}], but got {dim}" + assert len(start_offset) == len(src_stride), "start_offset and src_stride should have same size" + assert len(end_offset) == idx_rank + len( + start_offset) - 1, "len(end_offset) should be equal to index rank + len(start_offset) - 1" + + _assert_int_like_tuple("end_offset", end_offset) + _assert_int_like_tuple("start_offset", start_offset) + _assert_int_like_tuple("src_stride", src_stride) + + assert out, "out is required" + assert out.dtype == src.dtype.element_ty, "out should have same dtype as src" + + # use index type for end_offset, start_offset and src_stride. 
+ self.arg_type['end_offset'] = index.dtype + self.arg_type['start_offset'] = index.dtype + self.arg_type['src_stride'] = index.dtype + self.extra_attr = f"src_stride_len={len(src_stride)}" + + +@register_custom_op +class _index_put: + """This operation assigns values from the value UB tensor to the dst GM buffer + at positions with offsets specified by the index UB tensor along the specified + scatter dimension with SIMT template. This operation supports 2D-5D. + + Arguments: + - dst: the destination tensor pointer (in GM) + - index: the index tensor (in UB) + - value: the value tensor to be put (in UB) + - dim: the dimension on which index is applied + - bound: upper bound of index + - dst_shape: the shape of destination tensor + - dst_offset: the offset of each dimension in destination tensor + - dst_stride: the stride of each dimension in destination tensor + """ + name = '__builtin_index_put' + core = CORE.VECTOR + pipe = PIPE.PIPE_V + mode = MODE.SIMT + + def __init__(self, dst, index, value, dim, bound: tl.int64, dst_shape, dst_offset, dst_stride): + assert dst.type.is_ptr() or dst.dtype.is_ptr(), f"dst should be a pointer, but got {dst.type}" + assert index.dtype.is_int(), "index should be integer tensor" + value_rank = len(value.shape) + assert 2 <= value_rank <= 5, f"value rank should in [2, 5], but got {value_rank}" + assert isinstance(dim, int), "dim should be an integer" + assert isinstance(bound, int), "bound should be an integer" + assert 0 <= dim < value_rank - 1, f"dim should in [0, {value_rank - 1}), but got {dim}" + assert len(dst_shape) == len(dst_offset), "dst_shape and dst_offset should have same size" + assert len(dst_shape) == len(dst_stride), "dst_shape and dst_stride should have same size" + assert all(isinstance(x, int) for x in dst_shape), "dst_shape should all be integer" + assert all(isinstance(x, int) for x in dst_offset), "dst_offset should all be integer" + assert all(isinstance(x, int) for x in dst_stride), "dst_stride should all be integer" + + # use index type for dst_shape, dst_offset and dst_stride. + self.arg_type['dst_shape'] = index.dtype + self.arg_type['dst_offset'] = index.dtype + self.arg_type['dst_stride'] = index.dtype + + +@register_custom_op +class _gather_load: + """This operation takes a source memory GM buffer and a UB tensor of index, + and produces an output UB tensor by gathering elements from the source + at the index positions with offsets. This operation supports 1D-5D. 
+ + Arguments: + - src: pointer to the source memory GM buffer + - index: UB tensor that specifying the position in the src + - bound: upper bound of index + - dim: the dimension to gather along + - src_stride: the stride of the source tensor + - index_shape: the shape of the index tensor + - offsets: the offsets of each dimension for index tensor + - out: the gathered UB tensor, with the same shape as index.shape + """ + name = '__builtin_gather_load' + core = CORE.VECTOR + pipe = PIPE.PIPE_V + mode = MODE.SIMT + + def __init__(self, src, index, bound: tl.int64, dim, src_stride: tl.int64, index_shape, offsets, out=None): + assert src.type.is_ptr() or src.dtype.is_ptr(), f"src should be a pointer, but got {src.type}" + assert index.dtype.is_int(), "index should be an integer tensor" + assert isinstance(bound, int), "bound should be an integer" + assert isinstance(dim, int), "dim should be an integer" + idx_rank = len(index.shape) + assert 1 <= idx_rank <= 5, f"index rank should in [1, 5], but got {idx_rank}" + assert 0 <= dim < idx_rank, f"dim should in [0, {idx_rank}), but got {dim}" + assert len(src_stride) == idx_rank, f"src_stride size should be {idx_rank}" + assert len(index_shape) == idx_rank, f"index_shape size should be {idx_rank}" + assert len(offsets) == idx_rank, f"offsets size should be {idx_rank}" + assert all(isinstance(x, int) for x in src_stride), "src_stride should all be integer" + assert all(isinstance(x, int) for x in index_shape), "index_shape should all be integer" + assert all(isinstance(x, int) for x in offsets), "offsets should all be integer" + assert out, "out is required" + assert out.shape == index.shape, "Output should have same shape as index" + + +@register_custom_op +class _scatter_store: + """This operation assigns values from the UB tensor to the dst GM buffer at positions with + offsets specified by the UB index tensor with SIMT template. this operation supports 2D-5D. 
+ + Arguments: + - dst: the pointer of destination tensor GM memory buffer + - value: value tensor from UB to store + - index: index tensor from UB specifying positions in the destination tensor + - bound: upper bound of index + - dim: dimension along which the assignment operation is performed + - dst_stride: the strides of destination + - index_shape: the shape of index tensor + - offsets: the offsets for each dims of the index + """ + name = '__builtin_scatter_store' + core = CORE.VECTOR + pipe = PIPE.PIPE_V + mode = MODE.SIMT + + def __init__(self, dst, value, index, bound: tl.int64, dim, dst_stride: tl.int64, index_shape, offsets): + assert dst.type.is_ptr() or dst.dtype.is_ptr(), f"dst should be a pointer, but got {dst.type}" + assert index.dtype.is_int(), "index should be an integer tensor" + assert isinstance(value, tl.tensor), "value should be a tensor" + assert isinstance(bound, int), "bound should be an integer" + assert isinstance(dim, int), "dim should be an integer" + idx_rank = len(index.shape) + assert 1 <= idx_rank <= 5, f"index rank should in [1, 5], but got {idx_rank}" + assert 0 <= dim < idx_rank, f"dim should in [0, {idx_rank}), but got {dim}" + assert len(dst_stride) == idx_rank, f"dst_stride size should be {idx_rank}" + assert len(index_shape) == idx_rank, f"index_shape size should be {idx_rank}" + assert len(offsets) == idx_rank, f"offsets size should be {idx_rank}" + assert all(isinstance(x, int) for x in dst_stride), "dst_stride should all be integer" + assert all(isinstance(x, int) for x in index_shape), "index_shape should all be integer" + assert all(isinstance(x, int) for x in offsets), "offsets should all be integer" diff --git a/third_party/ascend/language/cann/extension/code_generator.py b/third_party/ascend/language/cann/extension/code_generator.py new file mode 100644 index 000000000..5bd4b67ff --- /dev/null +++ b/third_party/ascend/language/cann/extension/code_generator.py @@ -0,0 +1,201 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Ascend-specific code generation handlers for 'with' statement context managers. +""" + +__all__ = ["handle_scope_with", "mangle_ty"] +import ast + + +def mangle_ty(ty): + """ + Replacement implementation for triton.compiler.code_generator.mangle_ty. + + This is registered via ASCEND_WITH_DISPATCH["mangle_ty"] and picked up by + triton.compiler.code_generator through its global WITH_DISPATCH table. 
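+
+    Illustrative examples (assuming Triton's usual dtype spellings such as ``fp32``):
+    a pointer to fp16 mangles to ``Pfp16``, a signed 32-bit integer to ``i32``, and a
+    [64, 64] block of fp32 to ``fp32S64_64S``.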
+ """ + # Lazy imports to avoid circular dependencies at module import time. + from triton import language + from triton.extension.buffer.language import core as bl + + # Buffer types are Python-side dtypes; handle them first. + if isinstance(ty, bl.buffer_type): + elt = mangle_ty(ty.element_ty) + shape = "_".join(map(str, ty.shape)) + return f"B{elt}S{shape}S" + + if ty.is_ptr(): + return "P" + mangle_ty(ty.element_ty) + if ty.is_int(): + SIGNED = language.dtype.SIGNEDNESS.SIGNED + prefix = "i" if ty.int_signedness == SIGNED else "u" + return prefix + str(ty.int_bitwidth) + if ty.is_floating(): + return str(ty) + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = "_".join(map(str, ty.shape)) + return f"{elt}S{shape}S" + if ty.is_void(): + return "V" + raise TypeError(f"Unsupported type {ty}") + + +def _extract_scope_attributes(context_expr): + """Extract attributes from scope(...) call.""" + scope_attrs = {} + for keyword in context_expr.keywords: + if isinstance(keyword.value, ast.Constant): + scope_attrs[keyword.arg] = keyword.value.value + return scope_attrs + + +def _py_value_to_mlir_attr(builder, value): + """Convert Python value to MLIR attribute.""" + attr_creators = { + str: lambda v: builder.get_str_attr(v), + bool: lambda v: builder.get_bool_attr(v), + int: lambda v: builder.get_int32_attr(v), + list: lambda v: builder.get_i64_array_attr(v), + } + creator = attr_creators.get(type(value)) + return creator(value) if creator else value + + +def _handle_core_mode_attr(builder, core_mode): + """Handle core_mode attribute conversion.""" + if core_mode not in ("cube", "vector"): + return {} + return { + builder.get_t_core_type_attr_name(): + (builder.get_t_core_type_cube_attr() if core_mode == "cube" else builder.get_t_core_type_vector_attr()) + } + + +def _build_mlir_attrs_from_scope_attrs(builder, scope_attrs): + """Convert Python scope attributes to MLIR attributes. + + Args: + builder: The IR builder + scope_attrs: Dict of scope attributes (e.g., {'core_mode': 'vector', 'noinline': True}) + + Returns: + Dict of MLIR attributes + """ + mlir_attrs = {"noinline": builder.get_unit_attr()} + for k, v in scope_attrs.items(): + if k == "core_mode": + mlir_attrs.update(_handle_core_mode_attr(builder, v)) + elif k == "noinline": + if not v: + mlir_attrs.pop("noinline") + elif k == "disable_auto_sync": + if v: + mlir_attrs["hivm.disable_auto_sync"] = _py_value_to_mlir_attr(builder, v) + else: + mlir_attrs[k] = _py_value_to_mlir_attr(builder, v) + return mlir_attrs + + +def _verify_loop_carried_variable(_is_triton_value, _is_triton_tensor, name, loop_val, live_val): + """Verify that loop-carried variable types are consistent.""" + assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop' + assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the loop' + assert type(loop_val) == type(live_val), f'Loop carried variable {name} changed type' + assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \ + f'Loop-carried variable {name} has initial type {live_val.type} '\ + f'but is re-assigned to {loop_val.type} in loop! '\ + f'Please make sure that the type stays consistent.' + + +def _reconstruct_value_from_ir(language, entry_block_arg, ret_type): + """Reconstruct a tensor value from IR.""" + return language.core.tensor(entry_block_arg, ret_type) + + +def handle_scope_with(generator, node): + """ + Handle 'with scope(...)' statements by creating a scope.scope operation. 
+ + This creates a scope.scope operation with a region for the scope block. + Uses SSA threading to properly handle variables modified inside the scope. + + Args: + generator: The CodeGenerator instance + node: AST node for the with statement + """ + # Lazy imports to avoid circular dependency + from triton import language + from triton.compiler.code_generator import enter_sub_region, _is_triton_value, _is_triton_tensor + + context_expr = node.items[0].context_expr + scope_attrs = _extract_scope_attributes(context_expr) + + with enter_sub_region(generator) as sr: + liveins, _ = sr + ip, last_loc = generator._get_insertion_point_and_loc() + + # This implementation is similar to visit_while + dummy = generator.builder.create_block() + generator.builder.set_insertion_point_to_start(dummy) + generator.visit_compound_statement(node.body) + scope_defs = generator.local_defs + dummy.erase() + + # Verify and get return type of the scope.scope + # (variables that exist in parent scope AND are modified in scope) + names = [] + ret_types = [] + for name in scope_defs: + scope_val = scope_defs[name] + ret_types.append(scope_val.type) + names.append(name) + if name in liveins: + live_val = liveins[name] + _verify_loop_carried_variable(_is_triton_value, _is_triton_tensor, name, scope_val, live_val) + + # Convert Python primitives to MLIR attributes + mlir_attrs = _build_mlir_attrs_from_scope_attrs(generator.builder, scope_attrs) + + # Create scope operation with operands (values from outside) + generator._set_insertion_point_and_loc(ip, last_loc) + scope_op = generator.builder.create_scope_op(mlir_attrs, [ty.to_ir(generator.builder) for ty in ret_types]) + + # Create the entry block with arguments matching the operands + entry_block = generator.builder.create_block_with_parent(scope_op.get_region(0), []) + generator.builder.set_insertion_point_to_start(entry_block) + + # Initialize the scope's symbol table with liveins + generator.lscope = liveins.copy() + generator.visit_compound_statement(node.body) + generator.builder.set_insertion_point_to_end(entry_block) + reconstructed_values = [] + + for i in range(len(names)): + # generator.lscope[names[i]] is already a tensor, just get its IR handle + reconstructed_values.append(generator.lscope[names[i]].handle) + generator.builder.scope_return(reconstructed_values) + + # After exiting enter_sub_region, update symbol table with results + # Convert IR values back to tensor objects + for i, name in enumerate(names): + generator.set_value(name, _reconstruct_value_from_ir(language, scope_op.get_result(i), ret_types[i])) + return None diff --git a/third_party/ascend/language/cann/extension/core.py b/third_party/ascend/language/cann/extension/core.py new file mode 100644 index 000000000..1710d7b96 --- /dev/null +++ b/third_party/ascend/language/cann/extension/core.py @@ -0,0 +1,312 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +__all__ = [ + "ascend_address_space", "builtin", "CORE", "copy_from_ub_to_l1", "debug_barrier", "fixpipe", "FixpipeDMAMode", + "FixpipeDualDstMode", "FixpipePreQuantMode", "FixpipePreReluMode", "int64", "is_builtin", "MODE", "PIPE", + "sub_vec_id", "sub_vec_num", "sync_block_all", "sync_block_set", "sync_block_wait", "SYNC_IN_VF" +] + +import enum +from typing import TypeVar, List, Union +from functools import wraps + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +import triton.language.core as tl + +import triton.extension.buffer.language as bl +from triton.language.core import _constexpr_to_value +from triton.backends.ascend.driver import NPUUtils + +from . import semantic as semantic + +PIPE = semantic.PIPE + +T = TypeVar("T") + +TRITON_BUILTIN = "__triton_builtin__" +ASCEND_BUILTIN = "__ascend_builtin__" + + +def builtin(fn: T) -> T: + """Mark a function as a buffer language builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") + return fn(*args, **kwargs) + + # also set triton_builtin to true so that CodeGenerator will recognize this function + setattr(wrapper, TRITON_BUILTIN, True) + setattr(wrapper, ASCEND_BUILTIN, True) + + return wrapper + + +def is_builtin(fn) -> bool: + """Is this a registered ascend language builtin function?""" + return getattr(fn, ASCEND_BUILTIN, False) + + +class int64(int): + """ + For custom op, python int argument will be converted to int32 by default, + if a device-side int64 is required, you can pass an al.int64(x) to it. 
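+
+    Illustrative example (``al`` denotes this extension module; ``my_custom_op`` and its
+    ``bound`` argument are hypothetical):
+
+        my_custom_op(..., bound=al.int64(2**40))   # stays a 64-bit scalar on the device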
+ """ + + def __new__(cls, value): + obj = int.__new__(cls, value) + obj.type = tl.int64 + return obj + + +class CORE(enum.Enum): + VECTOR = ascend_ir.CoreType.VECTOR + CUBE = ascend_ir.CoreType.CUBE + CUBE_OR_VECTOR = ascend_ir.CoreType.CUBE_OR_VECTOR + CUBE_AND_VECTOR = ascend_ir.CoreType.CUBE_AND_VECTOR + + +class PIPE(enum.Enum): + PIPE_S = ascend_ir.PIPE.PIPE_S + PIPE_V = ascend_ir.PIPE.PIPE_V + PIPE_M = ascend_ir.PIPE.PIPE_M + PIPE_MTE1 = ascend_ir.PIPE.PIPE_MTE1 + PIPE_MTE2 = ascend_ir.PIPE.PIPE_MTE2 + PIPE_MTE3 = ascend_ir.PIPE.PIPE_MTE3 + PIPE_ALL = ascend_ir.PIPE.PIPE_ALL + PIPE_FIX = ascend_ir.PIPE.PIPE_FIX + + +class MODE(enum.Enum): + SIMD = ascend_ir.MODE.SIMD + SIMT = ascend_ir.MODE.SIMT + MIX = ascend_ir.MODE.MIX + + +class ascend_address_space_base(bl.address_space): + + def __init__(self, address_space_value: ascend_ir.AddressSpace) -> None: + super().__init__() + self.real_address_space = address_space_value + + def to_ir(self, builder: ir.builder) -> ir.attribute: + return builder.get_target_attribute(self.real_address_space) + + +class ascend_address_space_group: + + def __init__(self): + for k, v in {k: v + for k, v in ascend_ir.AddressSpace.__dict__.items() + if isinstance(v, ascend_ir.AddressSpace)}.items(): + setattr(self, k, ascend_address_space_base(v)) + + +ascend_address_space = ascend_address_space_group() + + +@builtin +def sub_vec_id(_builder=None) -> tl.tensor: + """ + Get the Vector Core index on the AI Core. + """ + return semantic.sub_vec_id(_builder) + + +@builtin +def copy_from_ub_to_l1(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], _builder: None) -> None: + """ + Copies data from the Unified Buffer (UB) to the L1 Buffer. + + :param src: The source data located in the Unified Buffer. + :type src: tl.tensor | bl.buffer + :param dst: The destination buffer located in L1 memory. 
+ :type dst: tl.tensor | bl.buffer + """ + return semantic.copy_from_ub_to_l1(src, dst, _builder) + + +def create_sync_block(sender, receiver, event_id, is_set: bool, sender_pipe=None, receiver_pipe=None, _builder=None): + sender = _constexpr_to_value(sender) + receiver = _constexpr_to_value(receiver) + assert isinstance(sender, str) and (sender == "cube" + or sender == "vector"), f"ERROR: sender = {sender}, only supports cube/vector" + assert isinstance(receiver, str) and (receiver == "cube" or receiver + == "vector"), f"ERROR: receiver = {receiver}, only supports cube/vector" + if isinstance(event_id, int): + assert (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + if sender == receiver: + raise ValueError(f'Unexpected pair: {sender} -> {receiver}, only supports cube -> vector or vector -> cube') + if sender_pipe is None and receiver_pipe is None: + if sender == "cube": + sender_pipe = PIPE.PIPE_FIX + receiver_pipe = PIPE.PIPE_MTE2 + if sender == "vector": + sender_pipe = PIPE.PIPE_MTE3 + receiver_pipe = PIPE.PIPE_MTE2 + if not isinstance(sender_pipe, PIPE) or not isinstance(receiver_pipe, PIPE): + raise TypeError("sender_pipe and receiver_pipe must be instances of PIPE enum") + if is_set: + return semantic.create_sync_block_set(sender, receiver, event_id, sender_pipe, receiver_pipe, _builder) + return semantic.create_sync_block_wait(sender, receiver, event_id, sender_pipe, receiver_pipe, _builder) + + +@builtin +def sync_block_set(sender, receiver, event_id, sender_pipe=None, receiver_pipe=None, _builder=None): + return create_sync_block(sender, receiver, event_id, True, sender_pipe, receiver_pipe, _builder) + + +@builtin +def sync_block_wait(sender, receiver, event_id, sender_pipe=None, receiver_pipe=None, _builder=None): + return create_sync_block(sender, receiver, event_id, False, sender_pipe, receiver_pipe, _builder) + + +@builtin +def sync_block_all(mode, event_id, _builder=None): + mode = _constexpr_to_value(mode) + event_id = _constexpr_to_value(event_id) + assert isinstance(mode, str), f"mode: {mode} is not string" + assert isinstance(event_id, int) and (event_id >= 0) and (event_id < 16), f"event_id: {event_id} should be 0 ~ 15" + assert mode in ("all_cube", "all_vector", "all", + "all_sub_vector"), f"ERROR: mode = {mode}, only supports all_cube/all_vector/all/all_sub_vector" + _builder.sync_block_all(mode, event_id) + + +class FixpipeDMAMode(enum.Enum): + NZ2DN = ascend_ir.FixpipeDMAMode.NZ2DN + NZ2ND = ascend_ir.FixpipeDMAMode.NZ2ND + NZ2NZ = ascend_ir.FixpipeDMAMode.NZ2NZ + + +class FixpipeDualDstMode(enum.Enum): + NO_DUAL = ascend_ir.FixpipeDualDstMode.NO_DUAL + COLUMN_SPLIT = ascend_ir.FixpipeDualDstMode.COLUMN_SPLIT + ROW_SPLIT = ascend_ir.FixpipeDualDstMode.ROW_SPLIT + + +class FixpipePreQuantMode(enum.Enum): + NO_QUANT = ascend_ir.FixpipePreQuantMode.NO_QUANT + F322BF16 = ascend_ir.FixpipePreQuantMode.F322BF16 + F322F16 = ascend_ir.FixpipePreQuantMode.F322F16 + S322I8 = ascend_ir.FixpipePreQuantMode.S322I8 + + +class FixpipePreReluMode(enum.Enum): + LEAKY_RELU = ascend_ir.FixpipePreReluMode.LEAKY_RELU + NO_RELU = ascend_ir.FixpipePreReluMode.NO_RELU + NORMAL_RELU = ascend_ir.FixpipePreReluMode.NORMAL_RELU + P_RELU = ascend_ir.FixpipePreReluMode.P_RELU + + +@builtin +def fixpipe( + src: tl.tensor, + dst: bl.buffer, + dma_mode: FixpipeDMAMode = FixpipeDMAMode.NZ2ND, + dual_dst_mode: FixpipeDualDstMode = FixpipeDualDstMode.NO_DUAL, + _builder=None, +) -> None: + """ + Directly store a tensor on L0C to a local buffer via fixpipe. 
+    Fixpipe is the pipeline that performs data movement from L0C to other memory hierarchies.
+    Currently supported:
+    - L0C to UB (for the Ascend910_95 series)
+
+    :param src: The source tensor. Must be located in the L0C memory region.
+    :type src: tl.tensor
+    :param dst: The destination buffer. Must be located in the UB memory region.
+    :type dst: bl.buffer
+    :param dma_mode: DMA transfer mode; FixpipeDMAMode.NZ2ND enables the NZ-to-ND layout transformation.
+    :type dma_mode: FixpipeDMAMode
+    :param dual_dst_mode: Dual-destination split mode.
+    :type dual_dst_mode: FixpipeDualDstMode
+    """
+    if not _builder.is_910_95():
+        raise RuntimeError("this feature is only supported on Ascend910_95")
+    if not isinstance(src, tl.tensor):
+        raise TypeError("src is not of tensor type")
+    elif not isinstance(dst, bl.buffer):
+        raise TypeError("dst is not of buffer type")
+    if dst.space != ascend_address_space.UB:
+        raise TypeError("dst must be located in the UB memory region")
+
+    if len(dst.shape) == 2 and (dst.type.element_ty == tl.float32 or dst.type.element_ty == tl.int32):
+        N = dst.shape[1]
+        if N % 8 != 0:
+            raise ValueError("32b Fixpipe last dim must be aligned to 8")
+        if (dma_mode != FixpipeDMAMode.NZ2ND) and (N % 16 != 0):
+            raise ValueError("32b non-NZ2ND Fixpipe last dim must be aligned to 16")
+        if (dual_dst_mode == FixpipeDualDstMode.COLUMN_SPLIT) and (N % 32 != 0):
+            raise ValueError("32b Column split dual Fixpipe last dim must be aligned to 32")
+        M = dst.shape[0]
+        if (dma_mode == FixpipeDMAMode.NZ2DN) and (M % 8 != 0):
+            raise ValueError("32b NZ2DN Fixpipe first dim must be aligned to 8")
+    dst16bits = (dst.type.element_ty == tl.float16 or dst.type.element_ty == tl.int16
+                 or dst.type.element_ty == tl.bfloat16)
+    if len(dst.shape) == 2 and dst16bits:
+        N = dst.shape[1]
+        if N % 16 != 0:
+            raise ValueError("16b Fixpipe last dim must be aligned to 16")
+        M = dst.shape[0]
+        if (dma_mode == FixpipeDMAMode.NZ2DN) and (M % 16 != 0):
+            raise ValueError("16b NZ2DN Fixpipe first dim must be aligned to 16")
+
+    return semantic.fixpipe(src, dst, dma_mode, dual_dst_mode, FixpipePreQuantMode.NO_QUANT, FixpipePreReluMode.NO_RELU,
+                            _builder)
+
+
+class SYNC_IN_VF(enum.Enum):
+    VV_ALL = enum.auto()
+    VST_VLD = enum.auto()
+    VLD_VST = enum.auto()
+    VST_VST = enum.auto()
+    VS_ALL = enum.auto()
+    VST_LD = enum.auto()
+    VLD_ST = enum.auto()
+    VST_ST = enum.auto()
+    SV_ALL = enum.auto()
+    ST_VLD = enum.auto()
+    LD_VST = enum.auto()
+    ST_VST = enum.auto()
+
+
+@builtin
+def debug_barrier(
+    sync_mode: SYNC_IN_VF,
+    _builder=None,
+) -> None:
+    return semantic.debug_barrier(sync_mode.name, _builder)
+
+
+@builtin
+def sub_vec_num(_builder=None) -> tl.constexpr:
+    """
+    Get the number of Vector Cores on one AI Core.
+    """
+    npuUtils = NPUUtils()
+    vec_core_num = npuUtils.get_aivector_core_num()
+    ai_core_num = npuUtils.get_aicore_num()
+    const_val = vec_core_num // ai_core_num
+    return tl.constexpr(const_val)
diff --git a/third_party/ascend/language/cann/extension/custom_op.py b/third_party/ascend/language/cann/extension/custom_op.py
new file mode 100644
index 000000000..b3352c26b
--- /dev/null
+++ b/third_party/ascend/language/cann/extension/custom_op.py
@@ -0,0 +1,273 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved.
+# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +__all__ = ["custom", "custom_semantic", "register_custom_op"] + +import inspect +import types +import typing +import itertools +import triton.language.core as tl +from . import core + +# Registry for custom op, mapping name to its configuration. +_custom_op_registry = {} + + +def _get_op_class(name): + # Try to get op class in _custom_op_registry. + op_class = _custom_op_registry.get(name) + if op_class is None: + # Allow bulitin custom ops used without registry. + assert name.startswith('__builtin_'), f"Custom Op '{name}' not registered." + # Return a dummy op class for builtin custom op. + op_class = type( + "_builtin_custom_op", (object, ), { + "name": name, + "core": core.CORE.VECTOR, + "pipe": core.PIPE.PIPE_V, + "mode": core.MODE.SIMT, + "signature": inspect.signature(object), + }) + return op_class + + +def _unwrap_constexpr(arg): + if isinstance(arg, tl.constexpr): + return arg.value + if isinstance(arg, tuple): + return tuple(_unwrap_constexpr(x) for x in arg) + if isinstance(arg, list): + return [_unwrap_constexpr(x) for x in arg] + if isinstance(arg, dict): + return {k: _unwrap_constexpr(v) for k, v in arg.items()} + return arg + + +def _to_value(value, builder, ty=None): + # Try to use 'type' attribute if ty not set. + ty = getattr(value, 'type', ty) if ty is None else ty + if isinstance(value, tl.tensor): + if not value.type.is_block() and isinstance(ty, tl.dtype) and value.type != ty: + # For a scalar variable, if its type is not the expected one + # that specified by type hint 'ty', insert a cast for it. 
+ return tl.semantic.cast(value, ty, builder).handle + return value.handle + if isinstance(value, bool): + return builder.get_int1(value) + if isinstance(value, int): + if isinstance(ty, tl.dtype): + if ty.is_int64(): + return builder.get_int64(value) + if ty.is_uint64(): + return builder.get_uint64(value) + if ty.is_int32(): + return builder.get_int32(value) + if ty.is_uint32(): + return builder.get_uint32(value) + if ty.is_int16(): + return builder.get_int16(value) + if ty.is_uint16(): + return builder.get_uint16(value) + if ty.is_int8(): + return builder.get_int8(value) + if ty.is_uint8(): + return builder.get_uint8(value) + # default int32 + return builder.get_int32(value) + if isinstance(value, float): + if isinstance(ty, tl.dtype): + if ty.is_fp64(): + return builder.get_fp64(value) + if ty.is_fp32(): + return builder.get_fp32(value) + if ty.is_fp16(): + return builder.get_fp16(value) + if ty.is_bf16(): + return builder.get_bf16(value) + # default float32 + return builder.get_fp32(value) + if isinstance(value, tl.constexpr): + return _to_value(value.value, builder) + raise TypeError(f"Unsupported argument type {value} : {type(value)}") + + +def _to_operands(args, builder): + operands = [] + for value in args: + if value is None: + continue + if isinstance(value, (list, tuple)): + for item in value: + operands.append(_to_value(item, builder)) + else: + operands.append(_to_value(value, builder)) + return operands + + +def _get_element_type(ty): + if isinstance(ty, types.GenericAlias): + return typing.get_args(ty)[0] + return ty + + +def _args_to_operands(op, builder, args, kwargs): + if not op.signature.parameters: + # Without parameters in signature, use the actual parameter order. + return _to_operands(itertools.chain(args, kwargs.values()), builder) + + # Convert arguments to operands according the signature. + operands = [] + bind = op.signature.bind(*args, **kwargs) + for param in op.signature.parameters.values(): + value = bind.arguments.get(param.name) + if value is None: + continue + ty = op.arg_type.get(param.name, param.annotation) + if isinstance(value, (list, tuple)): + ty = _get_element_type(ty) + for item in value: + operands.append(_to_value(item, builder, ty)) + else: + operands.append(_to_value(value, builder, ty)) + return operands + + +def _add_optional_attr(op, name, builder, attrs): + if hasattr(op, name): + attrs[name] = builder.get_str_attr(getattr(op, name)) + + +def _make_attrs(op, builder): + attrs = { + 'hivm.tcore_type': builder.get_core_type_attr(op.core.value), + 'hivm.pipe': builder.get_pipe_attr(op.pipe.value), + 'hivm.vf_mode': builder.get_vf_mode_attr(op.mode.value), + } + _add_optional_attr(op, 'symbol', builder, attrs) + _add_optional_attr(op, 'source', builder, attrs) + _add_optional_attr(op, 'compile', builder, attrs) + # Extra attributes can be added here, such as op.extra_attr="attr_a=xx" + _add_optional_attr(op, 'extra_attr', builder, attrs) + return attrs + + +def _to_result(res, res_types): + assert (len(res) == len(res_types)) + n_res = len(res) + if n_res == 0: + return None + if n_res == 1: + return tl.tensor(res[0], res_types[0]) + return tuple(tl.tensor(res[i], res_types[i]) for i in range(n_res)) + + +def _init_op(op_class, *args, **kwargs): + op = op_class.__new__(op_class) + # Add arg_type dict to support dynamic argument type specifying. + setattr(op, 'arg_type', {}) + if op_class.signature.parameters: + # Init with arguments validate. 
+ op_class.__init__(op, *args, **kwargs) + return op + + +def custom_semantic(name: str, *args, _builder=None, **kwargs): + name = _unwrap_constexpr(name) + # Get op class according the name. + op_class = _get_op_class(name) + # Convert constexpr to value in arguments. + args = _unwrap_constexpr(args) + kwargs = _unwrap_constexpr(kwargs) + # Create op instance from op class with the arguments. + op = _init_op(op_class, *args, **kwargs) + # Prepare inputs and outputs operands. + out = kwargs.pop('out', []) + outs = out if isinstance(out, (list, tuple)) else [out] + outputs = _to_operands(outs, _builder) + inputs = _args_to_operands(op, _builder, args, kwargs) + # Setup attributes. + attrs = _make_attrs(op, _builder) + # Build IR for the custom op. + res = _builder.create_custom_op(name, attrs, inputs, outputs) + # Results with same types as outputs. + res_types = [out.type for out in outs] + return _to_result(res, res_types) + + +@core.builtin +def custom(name: str, *args, _builder=None, **kwargs): + """Invoke a custom operation with the given name and arguments.""" + return custom_semantic(name, *args, _builder=_builder, **kwargs) + + +def register_custom_op(op): + """Register a custom operation so that we can invoke it using al.custom().""" + assert inspect.isclass(op), "@register_custom_op should decorate on a class." + # Use class name if name not set. + if not hasattr(op, 'name'): + setattr(op, 'name', op.__name__) + # The op name should not be used. + assert op.name not in _custom_op_registry, f"Custom op name '{op.name}' already used." + # Check required core, pipe, mode fields. + assert hasattr(op, 'core'), "'core' field is required." + assert hasattr(op, 'pipe'), "'pipe' field is required." + assert hasattr(op, 'mode'), "'mode' field is required." + assert isinstance(op.core, core.CORE), "Invalid 'core' field, CORE type is required." + assert isinstance(op.pipe, core.PIPE), "Invalid 'pipe' field, PIPE type is required." + assert isinstance(op.mode, core.MODE), "Invalid 'mode' field, MODE type is required." + # Retrieve arguments signature from __init__ method and save it. + signature = inspect.signature(op) + setattr(op, 'signature', signature) + # Register the custom op configuration. + _custom_op_registry[op.name] = op + return op + + +_dtype_cname_dict = { + 'int1': 'bool', + 'int8': 'int8_t', + 'int16': 'int16_t', + 'int32': 'int32_t', + 'int64': 'int64_t', + 'uint8': 'uint8_t', + 'uint16': 'uint16_t', + 'uint32': 'uint32_t', + 'uint64': 'uint64_t', + 'fp16': 'half', + 'bf16': 'bfloat16_t', + 'fp32': 'float', + 'fp64': 'double', + 'fp8e5': 'float8_e5m2_t', + 'fp8e4nv': 'float8_e4m3_t', + # other float8 types are not supported yet, + # such as 'fp8e4b8', 'fp8e4b15', 'fp8e5b16'. +} + + +def _cname(self): + """Return the corresponding C name of the given tl.dtype""" + return _dtype_cname_dict.get(self.name, self.name) + + +# Add 'cname' property to tl.dtype class. +tl.dtype.cname = property(_cname, None) diff --git a/third_party/ascend/language/cann/extension/dispatch.py b/third_party/ascend/language/cann/extension/dispatch.py new file mode 100644 index 000000000..d91abdee3 --- /dev/null +++ b/third_party/ascend/language/cann/extension/dispatch.py @@ -0,0 +1,33 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Dispatch table for Ascend-specific 'with' statement context managers. +""" + +from .scope import scope +from .code_generator import handle_scope_with, mangle_ty + +__all__ = ["ASCEND_WITH_DISPATCH"] + +# Registry of 'with' statement handlers for Ascend extension +ASCEND_WITH_DISPATCH = { + scope: handle_scope_with, + "mangle_ty": mangle_ty, +} diff --git a/third_party/ascend/language/cann/extension/math_ops.py b/third_party/ascend/language/cann/extension/math_ops.py new file mode 100644 index 000000000..cfd17a96f --- /dev/null +++ b/third_party/ascend/language/cann/extension/math_ops.py @@ -0,0 +1,54 @@ +from math import pi as math_pi +from triton.language import core, math +from triton.language.core import float32, int1 +from ..libdevice import atan, isnan, isinf +from triton.runtime.jit import jit + +pi: core.constexpr = math_pi + + +@core._tensor_member_fn +@jit +@math._add_math_2arg_docstr("atan2") +def atan2(y, x): + _is_int8_type_x: core.constexpr = x.dtype.is_int8() + core.static_assert(not _is_int8_type_x, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") + _is_int8_type_y: core.constexpr = y.dtype.is_int8() + core.static_assert(not _is_int8_type_y, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") + _is_floating_type_x: core.constexpr = x.dtype.is_floating() + core.static_assert(_is_floating_type_x == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}") + _is_floating_type_y: core.constexpr = y.dtype.is_floating() + core.static_assert(_is_floating_type_y == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(y.dtype)}") + half_pi: core.constexpr = 0.5 * pi + base = core.where(x == 0, 0.0, atan(y.to(core.dtype("fp32")) / x.to(core.dtype("fp32")))) + base = core.where((x == 0) & (y > 0), half_pi, base) + base = core.where((x == 0) & (y < 0), -half_pi, base) + + add_pi = core.where((x < 0) & (y >= 0), pi, 0.0) + sub_pi = core.where((x < 0) & (y < 0), -pi, 0.0) + return (base + add_pi + sub_pi).to(x.dtype) + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("isfinited") +def isfinited(x): + _is_int8_type: core.constexpr = x.dtype.is_int8() + core.static_assert(not _is_int8_type, f"Expected dtype fp16/fp32/bf16, but got int8 or int1") + _is_floating_type: core.constexpr = x.dtype.is_floating() + core.static_assert(_is_floating_type == True, f"Expected dtype fp16/fp32/bf16, but got {core.constexpr(x.dtype)}") + nan_mask = isnan(x) + inf_mask = isinf(x) + return 
(~nan_mask & ~inf_mask).to(int1) + + +@core._tensor_member_fn +@jit +@math._add_math_1arg_docstr("finitef") +def finitef(x): + _is_int8_type: core.constexpr = x.dtype.is_int8() + core.static_assert(not _is_int8_type, f"finitef only supports float32, but got int8 or int1") + core.static_assert(x.dtype == float32, f"finitef only supports float32, but got {core.constexpr(x.dtype)}") + nan_mask = isnan(x) + inf_mask = isinf(x) + return (~nan_mask & ~inf_mask).to(int1) diff --git a/third_party/ascend/language/cann/extension/mem_ops.py b/third_party/ascend/language/cann/extension/mem_ops.py new file mode 100644 index 000000000..859bbecd6 --- /dev/null +++ b/third_party/ascend/language/cann/extension/mem_ops.py @@ -0,0 +1,551 @@ +import numbers +import triton.language as tl +from triton.language import semantic as real_semantic +from triton.language.core import ( + _constexpr_to_value, + _tensor_member_fn, + _unwrap_iterable, + builtin, + constexpr, + dtype, + tensor, + check_bit_width, + _unwrap_if_constexpr, +) +from triton.language.semantic import ( + wrap_tensor, + _str_to_rounding_mode, + not_equal, + _str_to_dot_input_precision, + binary_op_type_checking_impl, + integer_promote_impl, + broadcast_impl_shape, + _str_to_sem, + _str_to_scope, + bitcast, + bitwise_op_type_checking_impl, + to_tensor, + _str_to_load_cache_modifier, + _str_to_eviction_policy, + _str_to_padding_option, + _canonicalize_boundary_check, +) + +from typing import Optional, Tuple, List, overload, Union +from triton._C.libtriton import ir + +from ._utils import _convert_elem_to_ir_value + + +@_tensor_member_fn +@builtin +def index_select(src: tensor, idx: tensor, bound, lstdim_blksiz, offsets, numels, _builder=None): + """ + Embedding + :src_ptr: + :idx: + """ + + def embedding_gather_impl(src: tl.tensor, idx: tl.tensor, bound: int, blksiz: int, offsets: Tuple, numels: Tuple, + builder: ir.builder) -> tl.tensor: + assert idx.dtype.is_int(), "index must be an integer tensor" + if not src.dtype.element_ty.is_floating(): + raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {src.dtype.element_ty}") + + require_i64 = idx.dtype.is_int64() + # require_i64 = True + offsets = [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in offsets] + numels = [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in numels] + ret = builder.create_embedding_gather(src.handle, idx.handle, bound, blksiz, offsets, numels) + ret_shape = [_unwrap_if_constexpr(s) for s in idx.shape] + ret_shape.append(blksiz) + return wrap_tensor(ret, src.dtype.element_ty, ret_shape) + + bound = _constexpr_to_value(bound) + lstdim_blksiz = _constexpr_to_value(lstdim_blksiz) + + return embedding_gather_impl(src, idx, bound, lstdim_blksiz, offsets, numels, _builder) + + +@_tensor_member_fn +@builtin +def index_put(ptr: tensor, index: tensor, value: tensor, dim: int, index_boundary: int, end_offset: tuple, + start_offset: tuple, dst_stride: tuple, _builder=None): + """ + Index put values from a tensor into a destination tensor. + + Index put operation for different tensor ranks: + 1. 2D index scatter (0 <= dim < 1): + 1.1 dim = 0 + out[index[i]][start_offset[1]:end_offset[1]] = value[i][0:end_offset[1]-start_offset[1]] + 2. 
3D index scatter (0 <= dim < 2): + 2.1 dim = 0 + out[index[i]][start_offset[1]:end_offset[1]][start_offset[2]:end_offset[2]] + = value[i][0:end_offset[1]-start_offset[1]][0:end_offset[2]-start_offset[2]] + 2.2 dim = 1 + out[start_offset[0]:end_offset[0]][index[j]][start_offset[2]:end_offset[2]] + = value[0:end_offset[0]-start_offset[0]][j][0:end_offset[2]-start_offset[2]] + + + :param ptr: pointer type, the destination tensor pointer (in GM) + :param index: tensor, a index to scatter (in UB) + :param value: tensor, a value to store (in UB) + :param dim: int32, the dimension to scatter along + :param index_boundary: int64, the upper boundary for index values + :param end_offset: tuple of int, the offsets of each dimension for the end of the scatter region + :param start_offset: tuple of int, the offsets of each dimension for the start of the scatter region + :param dst_stride: tuple of int, the stride of each dimension of destination tensor + + Constraints + *********** + - `ptr` and `value` must have the same rank. + - `ptr.dtype` only supports `float16`, `bfloat16`, `float32` currently. + - `index` must be an integer tensor. If `index.rank` != 1, it will be reshaped to 1D. + - `index.numel` must equal `value.shape[dim]`. + - `value` support 2~5D tensors. + - `dim` must be valid (0 <= dim < rank(value) - 1). + + Example + ******* + .. code-block:: python + + import torch + import triton + import triton.language as tl + from triton.language.extra.cann.extension import index_put + + @triton.jit + def simple_index_put_kernel(value_ptr, index_ptr, dst_ptr): + # index tile shape: [2] + index_local = tl.arange(0, 2) + x1_local = tl.arange(0, 2)[None, :] # shape=(1,2) + + index_tile = tl.load(index_ptr + index_local) + value_tile = tl.load(value_ptr + index_local[:, None]*2 + x1_local) + + index_put( + ptr=dst_ptr, + index=index_tile, + value=value_tile, + dim=0, + index_boundary=4, + end_offset=(2, 2), + start_offset=(0, 0), + dst_stride=(2, 1) + ) + + dst = torch.zeros((4,2), device='npu', dtype=torch.float32) + value = torch.tensor([[1.,2.], [3.,4.]], device='npu') + index = torch.tensor([2, 0], device='npu') + + simple_index_put_kernel[(1,)](value, index, dst) + print("IndexPut result:", dst) # ref:[[3.,4.], [0.,0.], [1.,2.], [0.,0.]] + """ + + def index_put_impl(ptr: tl.tensor, index: tl.tensor, value: tl.tensor, dim: int, index_boundary: int, + end_offset: Tuple, start_offset: Tuple, dst_stride: Tuple, builder: ir.builder): + assert index.dtype.is_int(), "index must be an integer tensor" + if not ptr.dtype.element_ty.is_floating(): + raise ValueError(f"Expected dtype fp16/fp32/bf16, but got {ptr.dtype.element_ty}") + if not isinstance(dim, int): + raise ValueError("dim must be of type tl.constexpr") + + v_rank = len(value.shape) + idx_rank = len(index.shape) + if v_rank < 2 or v_rank > 5: + raise ValueError(f"value rank must be in [2, 5], got value rank={v_rank}") + if dim < 0 or dim >= v_rank - 1: + raise ValueError(f"dim must satisfy 0<=dim 5: + raise ValueError(f"index rank must be in [1, 5], got rank={idx_rank}") + if dim < 0 or dim >= idx_rank: + raise ValueError(f"dim must satisfy 0<=dim 5: + raise ValueError(f"index rank must be in [1, 5], got rank={idx_rank}") + if dim < 0 or dim >= idx_rank: + raise ValueError(f"dim must satisfy 0<=dim 0 + + dim = _constexpr_to_value(dim) + index_boundary = _constexpr_to_value(index_boundary) + value = _constexpr_to_value(value) + + if not _is_ranked_tensor(value) or isinstance(value, constexpr): + element_ty = ptr.type.scalar.element_ty + value = 
real_semantic.full(index.shape, value, element_ty, _builder) + return scatter_ub_to_out_impl(ptr, value, index, index_boundary, dim, dst_stride, end_offset, start_offset, + _builder) + + +@_tensor_member_fn +@builtin +def index_select_simd(src, dim, index, src_shape, src_offset, read_shape, _builder=None) -> tensor: + """ + Parallel index_select operation from Global Memory to Unified Buffer (SIMD version). + + Selects data from multiple indices along a specified dimension and loads + them as tiles from GM directly to UB with zero-copy semantics. + + :param src: Source tensor pointer (in GM) + :type src: tensor (pointer type) + :param dim: The dimension along which to select indices + :type dim: int or constexpr + :param index: 1D tensor of indices to select (in UB) + :type index: tensor + :param src_shape: Complete shape of the source tensor (can be int or tensor) + :type src_shape: List[Union[int, tensor]] + :param src_offset: Starting offset for reading (can be int or tensor) + :type src_offset: List[Union[int, tensor]] + :param read_shape: Size to read (tile shape, can be int or tensor) + :type read_shape: List[Union[int, tensor]] + + **Constraints:** + + - ``read_shape[dim]`` must be ``-1`` + - ``src_offset[dim]`` can be ``-1`` (will be ignored) + - Boundary handling: ``src_offset + read_shape > src_shape`` automatically + truncates to ``src_shape`` boundary + - Does not check if ``index`` contains out-of-bounds values + + **Example:** + + .. code-block:: python + + @triton.jit + def kernel(src_ptr, output_ptr, indices_ptr, M, N, D, ...): + # Load indices (e.g., [5, 10, 15, 20]) + indices = tl.load(indices_ptr + tl.arange(0, 4)) + + # Example 1: Static shapes (constants) + # Index select from dimension 1 + # src: [8, 100, 256], index_select at dim=1 + # Read: [4, ?, 128] starting from [4, ?, 128] + result = extension.index_select_simd( + src_ptr, + dim=1, + index=indices, + src_shape=[8, 100, 256], + src_offset=[4, -1, 128], + read_shape=[4, -1, 128] + ) + # result shape: [4, 4, 128] + + # Example 2: Dynamic shapes (variables) + result2 = extension.index_select_simd( + src_ptr, + dim=1, + index=indices, + src_shape=[M, N, D], + src_offset=[4, -1, 128], + read_shape=[4, -1, 128] + ) + + tl.store(output_ptr + ..., result) + + :return: Result tensor in UB with shape where ``dim`` is replaced + by the length of ``index`` + :rtype: tensor + """ + + def index_select_simd_impl(src: tl.tensor, dim: int, index: tl.tensor, src_shape: List[Union[int, tl.tensor]], + src_offset: List[Union[int, tl.tensor]], read_shape: List[Union[int, tl.tensor]], + builder: ir.builder) -> tl.tensor: + # Validate inputs + ndim = len(src_shape) + assert len(src_offset) == ndim, \ + f"src_offset length {len(src_offset)} must match src_shape length {ndim}" + assert len(read_shape) == ndim, \ + f"read_shape length {len(read_shape)} must match src_shape length {ndim}" + assert 0 <= dim < ndim, \ + f"dim={dim} must be in range [0, {ndim})" + assert len(index.shape) == 1, \ + f"index must be 1D tensor, got {len(index.shape)}D" + assert dim < ndim - 1, \ + f"index_select_simd cannot support trailing dimension as dim={dim}, ndim={ndim}" + + newsrc_shape = [o.handle for o in src_shape] + newsrc_offset = [o.handle for o in src_offset] + # Create output type + return_shape = [index.shape[0] if i == dim else read_shape[i] for i in range(ndim)] + element_ty = src.type.element_ty + output_ty = tl.block_type(element_ty, return_shape) + out = builder.create_index_select_simd(src.handle, index.handle, dim, newsrc_shape, 
newsrc_offset, read_shape,
+                                               return_shape)
+        return tl.tensor(out, output_ty)
+
+    dim = _constexpr_to_value(dim)
+
+    # Process shape parameters: convert constexpr to values, keep tensors as-is
+    def process_param(val):
+        """Convert constexpr to value, keep tensor or int as-is"""
+        if isinstance(val, tensor):
+            return val
+        else:
+            return _constexpr_to_value(val)
+
+    newsrc_shape = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in src_shape]
+    newsrc_offset = [real_semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in src_offset]
+    assert len(index.shape) == 1, "index must be a 1D tensor"
+
+    return index_select_simd_impl(src, dim, index, newsrc_shape, newsrc_offset, read_shape, _builder)
diff --git a/third_party/ascend/language/cann/extension/scope.py b/third_party/ascend/language/cann/extension/scope.py
new file mode 100644
index 000000000..2db85d1be
--- /dev/null
+++ b/third_party/ascend/language/cann/extension/scope.py
@@ -0,0 +1,71 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+# Copyright 2018-2020 Philippe Tillet
+# Copyright 2020-2022 OpenAI
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+__all__ = ["scope"]
+
+from triton.language.core import _constexpr_to_value
+
+
+class scope:
+    """
+    Context manager for entering and exiting a scope, where operations within a scope share some common characteristics.
+
+    Example:
+    ```python
+    import triton.language.extra.cann.extension as extension
+
+    @triton.jit
+    def kernel(x_ptr, y_ptr, N):
+        # specify the core type plus extra annotations
+        with extension.scope(core_mode="cube", feature_a=True):
+            a = tl.load(x_ptr)
+            b = tl.load(y_ptr)
+            result = tl.dot(a, b)
+    ```
+
+    Reserved keywords:
+    - `core_mode`: Allows explicitly specifying which core type should be used for operations within a code block, helping the compiler generate appropriate code for cube or vector cores.
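+
+      For example, to force a block onto the vector core (illustrative; pointer names are placeholders):
+
+      ```python
+      with extension.scope(core_mode="vector"):
+          y = tl.load(y_ptr)
+          tl.store(out_ptr, y + 1)
+      ```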
+ """ + + def __init__(self, core_mode: str, _builder=None, _semantic=None, **kwargs): + """ + :param core_mode: Either "cube" or "vector" to specify the core type + :param _builder: Internal builder object (set by code_generator) + :param _semantic: Internal semantic object (set by code_generator) + :param kwargs: Additional internal parameters + """ + # Convert constexpr to value if not being called from code generator + self.core_mode = _constexpr_to_value(core_mode) if _builder is None else core_mode + self._builder = _builder + self._semantic = _semantic + + # Validate core_mode + if self.core_mode not in ("cube", "vector"): + raise ValueError(f'core_mode must be "cube" or "vector", got {self.core_mode}') + + def __enter__(self): + if self._builder is None: + raise RuntimeError("scope can only be used inside a Triton kernel") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False diff --git a/third_party/ascend/language/cann/extension/semantic.py b/third_party/ascend/language/cann/extension/semantic.py new file mode 100644 index 000000000..29df62e65 --- /dev/null +++ b/third_party/ascend/language/cann/extension/semantic.py @@ -0,0 +1,129 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
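+
+# Lowering helpers used by the extension builtins (see core.py): each function here emits AscendNPU IR via the builder.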
+ +__all__ = [ + "fixpipe", + "create_address_space", +] + +import enum +from typing import (TypeVar, List, Union) + +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +import triton.language.core as tl +import triton.language.extra.cann.extension as al +import triton.extension.buffer.language as bl + +from triton.language import semantic as real_semantic + +T = TypeVar('T') + + +def create_address_space(address_space: ascend_ir.AddressSpace, + builder: ascend_ir.ascendnpu_ir_builder) -> ir.attribute: + return builder.get_target_attribute(address_space) + + +class PIPE(enum.Enum): + PIPE_S = ascend_ir.PIPE.PIPE_S + PIPE_V = ascend_ir.PIPE.PIPE_V + PIPE_M = ascend_ir.PIPE.PIPE_M + PIPE_MTE1 = ascend_ir.PIPE.PIPE_MTE1 + PIPE_MTE2 = ascend_ir.PIPE.PIPE_MTE2 + PIPE_MTE3 = ascend_ir.PIPE.PIPE_MTE3 + PIPE_ALL = ascend_ir.PIPE.PIPE_ALL + PIPE_FIX = ascend_ir.PIPE.PIPE_FIX + + +def create_sync_block_set(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None): + if isinstance(event_id, int): + _builder.sync_block_set(sender, receiver, + real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, sender_pipe.value, + receiver_pipe.value) + elif isinstance(event_id, tl.constexpr): + _builder.sync_block_set(sender, receiver, + real_semantic.to_tensor(event_id, _builder).handle, sender_pipe.value, + receiver_pipe.value) + else: + _builder.sync_block_set(sender, receiver, event_id.handle, sender_pipe.value, receiver_pipe.value) + + +def create_sync_block_wait(sender, receiver, event_id, sender_pipe: PIPE, receiver_pipe: PIPE, _builder=None): + if isinstance(event_id, int): + _builder.sync_block_wait(sender, receiver, + real_semantic.to_tensor(tl.constexpr(event_id), _builder).handle, sender_pipe.value, + receiver_pipe.value) + elif isinstance(event_id, tl.constexpr): + _builder.sync_block_wait(sender, receiver, + real_semantic.to_tensor(event_id, _builder).handle, sender_pipe.value, + receiver_pipe.value) + else: + _builder.sync_block_wait(sender, receiver, event_id.handle, sender_pipe.value, receiver_pipe.value) + + +def sub_vec_id(builder: ascend_ir.ascendnpu_ir_builder) -> tl.tensor: + return tl.tensor(builder.create_get_sub_vec_id(), tl.int64) + + +def copy_from_ub_to_l1(src: Union[tl.tensor, bl.buffer], dst: Union[tl.tensor, bl.buffer], builder): + if not builder.is_910_95(): + raise RuntimeError("this feature is only supported on Ascend910_95") + if isinstance(src, tl.tensor) or isinstance(dst, tl.tensor): + raise TypeError("tensor not support yet") + if src.shape != dst.shape: + raise TypeError("src and dst must have same shape") + if src.dtype != dst.dtype: + raise TypeError("src and dst need to have the same type") + if isinstance(src, bl.buffer) and isinstance(dst, bl.buffer): + if src.space != al.ascend_address_space.UB: + raise TypeError("src's AddressSpace must be UB") + if dst.space != al.ascend_address_space.L1: + raise TypeError("dst's AddressSpace must be L1") + builder.create_copy_buffer(src.handle, dst.handle) + else: + raise TypeError("src and dst must be tl.tensor or bl.buffer") + + +def fixpipe( + src: tl.tensor, + dst, + dma_mode, + dual_dst_mode, + pre_quant_mode, + pre_relu_mode, + builder: ascend_ir.ascendnpu_ir_builder, +) -> None: + builder.create_fixpipe( + src.handle, + dst.handle, + dma_mode.value, + dual_dst_mode.value, + pre_quant_mode.value, + pre_relu_mode.value, + ) + + +def debug_barrier(sync_mode: str, builder) -> None: + target = tl.tensor(builder.get_int64(0), tl.int64) + attr = 
builder.get_str_attr(sync_mode) + builder.create_debug_barrier(target.handle, "SYNC_IN_VF", attr) diff --git a/third_party/ascend/language/cann/extension/vec_ops.py b/third_party/ascend/language/cann/extension/vec_ops.py new file mode 100644 index 000000000..effbbc0fa --- /dev/null +++ b/third_party/ascend/language/cann/extension/vec_ops.py @@ -0,0 +1,535 @@ +# insert_slice +# extract_slice +# get_element +# sort +# flip +# gather + +import triton.language as tl +from triton.language import semantic, core, standard +from triton.language.core import (_constexpr_to_value, _tensor_member_fn, _unwrap_iterable, builtin, constexpr, dtype, + tensor, check_bit_width, _unwrap_if_constexpr, range) +from triton.language.semantic import ( + wrap_tensor, + _str_to_rounding_mode, + not_equal, + _str_to_dot_input_precision, + binary_op_type_checking_impl, + integer_promote_impl, + broadcast_impl_shape, + _str_to_sem, + _str_to_scope, + bitcast, + bitwise_op_type_checking_impl, + to_tensor, + _str_to_load_cache_modifier, + _str_to_eviction_policy, + _str_to_padding_option, + _canonicalize_boundary_check, +) + +from . import is_compile_on_910_95 +from .aux_ops import compile_hint_impl + +from typing import Optional, Tuple, List, overload +from triton._C.libtriton import ir + + +@_tensor_member_fn +@builtin +def insert_slice(ful, sub, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Insert a tensor to another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to receive tensor. + :type ful: Tensor + :param sub: The tensor to be inserted. + :type sub: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + + def insert_slice_impl(ful: tensor, sub: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], + builder: ir.builder) -> tensor: + assert (len(ful.shape) == len(offsets)) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + assert (all([s >= 1 for s in sizes])) + assert (all([s >= 0 for s in strides])) + # Handle both tensor and int offsets (for interpreter mode) + new_offsets = [] + for o in offsets: + if isinstance(o, tensor): + new_offsets.append(o.handle) + elif isinstance(o, int): + # For interpreter mode: keep as int + new_offsets.append(o) + else: + new_offsets.append(o.handle if hasattr(o, 'handle') else o) + ret_type = tl.block_type(ful.type.scalar, ful.shape) + out = builder.create_insert_slice(ful.handle, sub.handle, new_offsets, sizes, strides) + return tensor(out, ret_type) + + assert len(ful.shape) > 0 + assert len(ful.shape) == len(sub.shape) + new_offsets = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] + out = insert_slice_impl(ful, sub, new_offsets, sizes, strides, _builder) + return out + + +@_tensor_member_fn +@builtin +def extract_slice(ful, offsets, sizes, strides, _builder=None, _generator=None) -> tensor: + """ + Extract a tensor from another tensor as specified by the operation’s offsets, sizes and strides arguments. + + :param ful: The tensor to split. 
+ :type ful: Tensor + :param offsets: + :type offsets: tuple of ints + :param sizes: + :type sizes: tuple of ints + :param strides: + :type strides: tuple of ints + """ + + def extract_slice_impl(ful: tensor, offsets: List[tensor], sizes: List[int], strides: List[int], + builder: ir.builder) -> tensor: + assert (len(ful.shape) == len(offsets)) + assert (len(ful.shape) == len(sizes)) + assert (len(ful.shape) == len(strides)) + assert (all([s >= 1 for s in sizes])) + assert (all([s >= 0 for s in strides])) + # Handle both tensor and int offsets (for interpreter mode) + new_offsets = [] + for o in offsets: + if isinstance(o, tensor): + new_offsets.append(o.handle) + elif isinstance(o, int): + # For interpreter mode: keep as int + new_offsets.append(o) + else: + new_offsets.append(o.handle if hasattr(o, 'handle') else o) + ret_type = tl.block_type(ful.type.scalar, sizes) + out = builder.create_extract_slice(ful.handle, new_offsets, sizes, strides) + return tensor(out, ret_type) + + assert len(ful.shape) > 0 + new_offsets = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in offsets] + sub = extract_slice_impl(ful, new_offsets, sizes, strides, _builder) + return sub + + +@_tensor_member_fn +@builtin +def get_element(src, indice, _builder=None, _generator=None): + """ + get_element op reads a ranked tensor and returns one element as specified by the given indices. + The result of the op is a value with the same type as the elements of the tensor. + The arity of indices must match the rank of the accessed value. + + :param src: The tensor to be accessed. + :type src: Tensor + :param indice: + :type indice: tuple of ints + """ + + def get_element_impl(src: tensor, indice: List[tensor], builder: ir.builder): + if len(src.shape) != len(indice): + raise ValueError("Indice's rank must be equal to src tensor's rank") + + # Handle both tensor and int indices (for interpreter mode) + new_indice = [] + for i in indice: + if isinstance(i, tensor): + new_indice.append(i.handle) + elif isinstance(i, int): + # For interpreter mode: convert int to TensorHandle + new_indice.append(i) + else: + # Try to use .handle attribute if available + new_indice.append(i.handle if hasattr(i, 'handle') else i) + + result = builder.create_extract_scalar(src.handle, new_indice) + return wrap_tensor(result, src.type.scalar, None) + + assert len(src.shape) > 0 + new_indice = [semantic.to_tensor(i, _builder) if isinstance(i, constexpr) else i for i in indice] + return get_element_impl(src, new_indice, _builder) + + +@builtin +def flip(ptr, dim=-1, _builder=None, _generator=None): + + def flip_impl(ptr: tensor, dim: int, builder: ir.builder, generator=None): + """ + Flips a tensor `ptr` along the dimension `dim`. 
+ + :param ptr: the first input tensor + :type ptr: tensor + :param dim: the dimension to flip along + :type dim: int + :param generator: the code generator (required for reduce operations) + :type generator: generator object + """ + + def _get_flip_dim(dim, shape): + dim = _unwrap_if_constexpr(dim) + shape = _unwrap_if_constexpr(shape) + if dim is None: + dim = len(shape) - 1 + if dim < 0: # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index + dim += len(shape) + return constexpr(dim) + + def _log2(i: core.constexpr): + log2 = 0 + n = core.constexpr(i).value + while n > 1: + n >>= 1 + log2 += 1 + return core.constexpr(log2) + + def flip_simd(ptr: tensor, dim: int, builder: ir.builder): + """ + Triton flip operation for simd + + Args: + ptr: tensor, input tensor + dim: int, dimension to flip (can be negative, normalized here) + builder: ir.builder, underlying IR builder + Returns: + flipped: tensor, same type and shape as input + """ + + shape = getattr(ptr, "shape", None) + if shape is None or shape == (): + shape = getattr(getattr(ptr, "type", None), "shape", None) + + rank = None + if shape is not None: + try: + rank = len(shape) + except Exception: + rank = len(list(shape)) + + if rank is not None: + if rank < 1: + raise ValueError("ascend.flip requires tensor rank >= 1") + norm_dim = dim if dim >= 0 else dim + rank + if not (0 <= norm_dim < rank): + raise ValueError(f"ascend.flip got invalid dim={dim} for shape {tuple(shape)}") + dim = norm_dim + else: + if dim < 0: + raise ValueError("ascend.flip with unknown rank requires non-negative dim") + + flipped_vals = builder.create_flip(ptr.handle, dim) + flipped = tensor(flipped_vals, type=ptr.type) + return flipped + + # If compile_mode is not simt, use the simd implementation + if not builder.is_simt_mode(): + return flip_simd(ptr, dim, builder) + core.static_assert(-len(ptr.shape) <= dim and dim < len(ptr.shape), _builder=builder) + _dim: core.constexpr = _get_flip_dim(dim, ptr.shape) + core.static_assert(standard._is_power_of_two(ptr.shape[_dim]), _builder=builder) + steps: core.constexpr = _log2(ptr.shape[_dim]) + # If steps is 0, return the original tensor + if steps == 0: + return ptr + # reshape the swap dimension to (2, 2, ..., 2) + idtype = core.get_int_dtype(bitwidth=ptr.dtype.primitive_bitwidth, signed=True) + y = core.reshape( + ptr.to(idtype, bitcast=True, _builder=builder), + ptr.shape.__getitem__(slice(None, _dim)) + [2] * steps + ptr.shape.__getitem__(slice(_dim + 1, None)), + _builder=builder) + for i in static_range(steps): + y = y.__xor__(standard.xor_sum(y, _dim + i, True, _builder=builder, _generator=generator), _builder=builder) + ptr = core.reshape(y, ptr.shape, _builder=builder).to(ptr.dtype, bitcast=True, _builder=builder) + return ptr + + try: + dim = int(dim.value) if hasattr(dim, "value") else int(dim) + except Exception as e: + raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}") from e + + dim = len(ptr.shape) - 1 if dim == -1 else dim + return flip_impl(ptr, dim, _builder, _generator) + + +class static_range: + """ + Iterator for non-JIT Python functions that need to iterate over constexpr values. + This is used in functions like flip that are called during compilation. 
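+
+    Example (illustrative):
+        for i in static_range(0, 8, 2):
+            ...  # i takes the plain Python int values 0, 2, 4, 6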
+ """ + + def __init__(self, arg1, arg2=None, step=None): + if step is None: + self.step = core.constexpr(1) + else: + self.step = step + if arg2 is None: + self.start = core.constexpr(0) + self.end = arg1 + else: + self.start = arg1 + self.end = arg2 + + def __iter__(self): + # Extract actual values from constexpr objects for iteration + start_val = core._constexpr_to_value(self.start) + end_val = core._constexpr_to_value(self.end) + step_val = core._constexpr_to_value(self.step) + # Store as regular Python integers for iteration + self._current = start_val + self._end = end_val + self._step = step_val + return self + + def __next__(self): + if self._current >= self._end: + raise StopIteration + value = self._current + self._current += self._step + return value + + +@builtin +def sort(ptr, dim=-1, descending=False, _builder=None): + """ + sort the input tensor along 'dim' + + param: + ptr: tensor, input tensor + dim: int or tl.constexpr[int], dimension to sort + descending: bool or tl.constexpr[bool], the result is descending or not + _builder: ir.builder + return: + values: tensor, the sorted tensor + """ + + def sort_impl(ptr: tensor, dim: int, descending, builder: ir.builder): + allowed_types = { + tl.int8, tl.int16, tl.bfloat16, tl.float16, tl.float32, tl.int32, tl.int64, tl.float8e4nv, tl.float8e5 + } + base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type + if base_ty not in allowed_types: + raise TypeError( + f"ascend.sort only supports int8, int16, bfloat16, float16, float32, int32, int64, float8e4nv, float8e5" + f"but got {ptr.type}") + + shape = getattr(ptr, "shape", None) + if shape is None or shape == (): + shape = getattr(getattr(ptr, "type", None), "shape", None) + + rank = None + if shape is not None: + try: + rank = len(shape) + except Exception: + rank = len(list(shape)) + + if rank is not None: + if rank < 1: + raise ValueError("ascend.sort requires tensor rank >= 1") + last_dim = rank - 1 + norm_dim = dim if dim >= 0 else dim + rank + if norm_dim != last_dim: + raise ValueError(f"ascend.sort only supports sorting along the last dimension " + f"(dim={last_dim} or -1) for shape {tuple(shape)}, but got dim={dim}") + dim = last_dim + else: + if dim != -1: + raise ValueError("ascend.sort only supports the last dimension; when rank is unknown " + "you must pass dim=-1") + + if hasattr(descending, "value"): + descending = bool(descending.value) + else: + descending = bool(descending) + + sorted_vals = builder.create_sort(ptr.handle, dim, descending) + + values = tensor(sorted_vals, type=ptr.type) + + return values + + try: + dim = int(dim.value) if hasattr(dim, "value") else int(dim) + except Exception as e: + raise TypeError(f"dim must be an integer (or tl.constexpr int), got {dim!r}. 
Error: {str(e)}") from e + + if hasattr(descending, "value"): + descending = bool(descending.value) + else: + descending = bool(descending) + + ret = sort_impl(ptr, dim, descending, _builder) + base_ty = ptr.type.scalar if hasattr(ptr.type, "scalar") else ptr.type + if base_ty.is_int8() or base_ty.is_int16(): + compile_hint_impl(ret, "overflow_mode", constexpr("saturate"), _builder) + return ret + + +def ascend_cast_impl(input: tensor, dst_ty: dtype, builder: ir.builder, fp_downcast_rounding: Optional[str] = None, + overflow_mode: Optional[str] = None) -> tensor: + src_ty = input.type + if isinstance(dst_ty, tl.constexpr): + dst_ty = dst_ty.value + if isinstance(fp_downcast_rounding, tl.constexpr): + fp_downcast_rounding = fp_downcast_rounding.value + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty == dst_sca_ty: + return input + + # For fp downcasting default rounding mode should be RTNE, for all other conversions it should + # not be set + fp_downcast_rounding = _str_to_rounding_mode(fp_downcast_rounding) + use_custom_rounding = False + if dst_sca_ty.is_floating() and src_sca_ty.is_floating( + ) and dst_sca_ty.primitive_bitwidth < src_sca_ty.primitive_bitwidth: + if fp_downcast_rounding is None: fp_downcast_rounding = ir.ROUNDING_MODE.RTNE + elif fp_downcast_rounding != ir.ROUNDING_MODE.RTNE: use_custom_rounding = True + else: + if fp_downcast_rounding is not None: + raise ValueError("fp_downcast_rounding should be set only for truncating fp conversions. " + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + if not is_compile_on_910_95: + if (src_sca_ty.is_fp8() or dst_sca_ty.is_fp8()) or (src_sca_ty.is_fp64() or dst_sca_ty.is_fp64()): + raise ValueError("[fp8, fp64] is unsupported on Ascend for now." + "Source scalar type is " + str(src_sca_ty) + " and destination type is " + str(dst_sca_ty)) + if (src_sca_ty.is_fp8e4b15() or dst_sca_ty.is_fp8e4b15()): + assert builder.codegen_fns.get( + "convert_custom_types") is not None, "target doesn't provide conversion for this type." 
+ return builder.codegen_fns["convert_custom_types"](input, dst_ty, fp_downcast_rounding, _builder=builder) + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + # and non-default rounding modes for downcasting + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()) or \ + use_custom_rounding: + return tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder), fp_downcast_rounding), dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return ascend_cast_impl(ascend_cast_impl(input, tl.float32, builder), dst_sca_ty, builder) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif overflow_mode == "saturate" and \ + (src_sca_ty.is_int_unsigned() or dst_sca_ty.is_int_unsigned()) and \ + src_sca_ty.int_bitwidth >= dst_sca_ty.int_bitwidth: + return ascend_cast_impl(ascend_cast_impl(input, tl.float32, builder), dst_sca_ty, builder) + return tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tensor(builder.create_ui_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + else: + return tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) + if bitwidth == 1: + return not_equal(ascend_cast_impl(input, tl.int64, builder), tensor(builder.get_int64(0), tl.int64), + builder) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return 
tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + + +@_tensor_member_fn +@builtin +def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast: bool = False, + overflow_mode: Optional[str] = None, _builder=None): + """ + Casts a tensor to the given :code:`dtype`. + + :param dtype: The target data type. + :type dtype: dtype + :param fp_downcast_rounding: The rounding mode for downcasting + floating-point values. This parameter is only used when self is a + floating-point tensor and dtype is a floating-point type with a + smaller bitwidth. Supported values are :code:`"rtne"` (round to + nearest, ties to even) and :code:`"rtz"` (round towards zero). + :type fp_downcast_rounding: str, optional + :param bitcast: If true, the tensor is bitcasted to the given + :code:`dtype`, instead of being numerically casted. + :type bitcast: bool, optional + :param overflow_mode: When overflow_mode is not set or is "trunc", + truncation (cut-off) will be used to handle overflow. When + overflow_mode is "sautrate", the maximum value of the data type + will be used to handle overflow. + :type overflow_mode: string, optional + """ + overflow_modes = ["trunc", "saturate"] + input = semantic.to_tensor(input, _builder) + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(input, dtype, _builder) + ret = ascend_cast_impl(input, dtype, _builder, fp_downcast_rounding, overflow_mode) + if overflow_mode is not None: + if overflow_mode in overflow_modes: + compile_hint_impl(ret, "overflow_mode", overflow_mode, _builder) + else: + raise ValueError(f"Unknown overflow_mode:{overflow_mode} is found.") + return ret diff --git a/third_party/ascend/language/ascend/libdevice.py b/third_party/ascend/language/cann/libdevice.py similarity index 58% rename from third_party/ascend/language/ascend/libdevice.py rename to third_party/ascend/language/cann/libdevice.py index 50d1f145f..eaba0a831 100644 --- a/third_party/ascend/language/ascend/libdevice.py +++ b/third_party/ascend/language/cann/libdevice.py @@ -1,202 +1,28 @@ -from functools import wraps -from typing import List -import numbers -from triton.language import core -from triton.language.core import ( - _constexpr_to_value, - constexpr, - tensor, -) -from triton.language.math import _add_math_1arg_docstr, _add_math_2arg_docstr, _add_math_3arg_docstr -from triton.language import semantic -from triton._C.libtriton import ir -from triton.language import math, semantic -from math import pi as math_pi - -T = core.TypeVar('T') - - -def _check_dtype(dtypes: List[str]) -> T: - """ - We're following libdevice's convention to check accepted data types for math functions. - It is not a good practice to support all data types as accelerators/GPUs don't support - many float16 and bfloat16 math operations. - We should let the users know that they are using and invoke explicit cast to convert - the data type to the supported one. 
- """ - - def wrapper(fn): - - @wraps(fn) - def check(*args, **kwargs): - # concatenate args and kwargs - all_args = list(args) + list(kwargs.values()) - for arg in [a for a in all_args if isinstance(a, core.tensor)]: - arg_type = arg.type.scalar.name - if hasattr(arg, 'was_bool_to_int8') and arg.was_bool_to_int8: - # In Triton, int1 maps to the boolean type - arg_type = 'int1' - if arg_type not in dtypes: - raise ValueError(f"Expected dtype {dtypes} but got {arg_type}") - return fn(*args, **kwargs) - - return check - - return wrapper - - -@core.extern -@_check_dtype(dtypes=["int32", "uint32"]) -@_add_math_2arg_docstr("most significant N bits of the 2N-bit product") -def umulhi(x, y, _builder=None): - x = semantic.to_tensor(x, _builder) - y = semantic.to_tensor(y, _builder) - x, y = core.binary_op_type_legalization(x, y, _builder) - return core.tensor(_builder.create_umulhi(x.handle, y.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5"]) -@_add_math_1arg_docstr("exponential") -@core._tensor_member_fn -def exp(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_exp(x.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5"]) -@_add_math_1arg_docstr("exponential (base 2)") -@core._tensor_member_fn -def exp2(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_exp2(x.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5"]) -@_add_math_1arg_docstr("natural logarithm") -@core._tensor_member_fn -def log(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_log(x.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5"]) -@_add_math_1arg_docstr("logarithm (base 2)") -@core._tensor_member_fn -def log2(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_log2(x.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5"]) -@_add_math_1arg_docstr("cosine") -@core._tensor_member_fn -def cos(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_cos(x.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5"]) -@_add_math_1arg_docstr("sine") -@core._tensor_member_fn -def sin(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_sin(x.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5"]) -@_add_math_1arg_docstr("fast square root") -@core._tensor_member_fn -def sqrt(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_sqrt(x.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5"]) -@_add_math_1arg_docstr("precise square root (rounding to nearest wrt the IEEE standard)") -@core._tensor_member_fn -def sqrt_rn(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_precise_sqrt(x.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5"]) -@_add_math_1arg_docstr("inverse square root") -@core._tensor_member_fn -def rsqrt(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_rsqrt(x.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", 
"fp32", "fp8e4nv", "fp8e5"]) -@_add_math_2arg_docstr("precise division (rounding to nearest wrt the IEEE standard)") -def div_rn(x, y, _builder=None): - x = semantic.to_tensor(x, _builder) - y = semantic.to_tensor(y, _builder) - x, y = core.binary_op_type_legalization(x, y, _builder) - return core.tensor(_builder.create_precise_divf(x.handle, y.handle), x.type) - +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5"]) -@_add_math_1arg_docstr("error function") -@core._tensor_member_fn -def erf(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_erf(x.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("error function") -@core._tensor_member_fn -def tanh(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_tanh(x.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5"]) -@_add_math_1arg_docstr("floor") -@core._tensor_member_fn -def floor(x, _builder=None): - x = semantic.to_tensor(x, _builder) - return core.tensor(_builder.create_floor(x.handle), x.type) - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32"]) -@_add_math_1arg_docstr("ceil") -@core._tensor_member_fn -def ceil(x, _builder=None): - x = semantic.to_tensor(x, _builder) - if x.type.scalar.is_int(): - return x - elif x.type.scalar.is_floating(): - return core.tensor(_builder.create_ceil(x.handle), x.type) - raise ValueError("ceil does not support boolean type") - - -@core.extern -@_check_dtype(dtypes=["bf16", "fp16", "fp32", "fp8e4nv", "fp8e5"]) -@_add_math_3arg_docstr("fused multiply-add") -def fma(x, y, z, _builder=None): - x = semantic.to_tensor(x, _builder) - y = semantic.to_tensor(y, _builder) - z = semantic.to_tensor(z, _builder) - x, y = core.binary_op_type_legalization(x, y, _builder) - z, x = core.binary_op_type_legalization(z, x, _builder) - z, y = core.binary_op_type_legalization(z, y, _builder) - return core.tensor(_builder.create_fma(x.handle, y.handle, z.handle), x.type) +from math import pi as math_pi +from triton.language import core, math, semantic +from triton._C.libtriton import ir +from triton.runtime.jit import jit +from triton.backends.ascend.utils import get_ascend_arch_from_env @core.extern @@ -276,8 +102,8 @@ def ilogb(arg0, _builder=None): def ldexp(arg0, arg1, 
_builder=None): return core.extern_elementwise( "", "", [arg0, arg1], { - (core.dtype("fp32"), core.dtype("fp32")): ("__hmf_ldexpf", core.dtype("fp32")), - (core.dtype("fp16"), core.dtype("fp16")): ("__hmf_ldexpDh", core.dtype("fp16")), + (core.dtype("fp32"), core.dtype("int32")): ("__hmf_ldexpf", core.dtype("fp32")), + (core.dtype("fp16"), core.dtype("int32")): ("__hmf_ldexpDh", core.dtype("fp16")), }, is_pure=True, _builder=_builder) @@ -305,21 +131,6 @@ def isnan(arg0, _builder=None): }, is_pure=True, _builder=_builder) -@core.extern -def flip(arg0, arg1=None, _builder=None): - if _builder is None: - from triton.language import core as tl - _builder = tl.get_builder() - from triton.language.core import flip as ascend_flip - return ascend_flip(arg0, arg1, _builder=_builder) - - -@core.extern -def atan2(arg0, arg1, _builder=None): - core.static_print("The func atan2 is supported in math.atan2 lowlevel. So use math.atan2 instead.") - return math.atan2(arg1, arg0) - - @core.extern def div_rz(arg0, arg1, _builder=None): core.static_print("tl.div_rz is unsupported for now. Use libdevice.div_rz instead.") @@ -332,6 +143,12 @@ def fmod(arg0, arg1, _builder=None): core.static_assert(False) +@core.extern +def trunc(arg0, _builder=None): + core.static_print("tl.trunc is unsupported for now. Use libdevice.trunc instead.") + core.static_assert(False) + + @core.extern def round(arg0, _builder=None): return core.extern_elementwise("", "", [arg0], { @@ -339,40 +156,6 @@ def round(arg0, _builder=None): }, is_pure=True, _builder=_builder) -@core.extern -@_add_math_2arg_docstr("cdiv") -@core._tensor_member_fn -def cdiv(x, div, _builder=None): - if isinstance(x, core.constexpr): - x = x.value - if isinstance(div, core.constexpr): - div = div.value - from math import ceil as py_ceil - if isinstance(x, numbers.Number) and isinstance(div, numbers.Number): - if isinstance(x, bool) or isinstance(div, bool): - raise TypeError("cdiv does not support boolean type") - if isinstance(x, int) and isinstance(div, int): - res = x // div - rem = x % div - return res + (1 if rem != 0 else 0) - else: - return py_ceil(x / div) - - x = semantic.to_tensor(x, _builder) - div = semantic.to_tensor(div, _builder) - x_scalar_type = x.type.scalar - div_scalar_type = div.type.scalar - if x_scalar_type.is_bool() or div_scalar_type.is_bool(): - raise ValueError("cdiv does not support boolean type") - elif x_scalar_type.is_int() and div_scalar_type.is_int(): - # integer cdiv: (x + div - 1) // div as before - return semantic.floordiv(semantic.add(x, semantic.sub(div, 1, True, _builder), True, _builder), div, _builder) - else: - div_res = semantic.truediv(x, div, _builder) - cdiv_res = core.tensor(_builder.create_ceil(div_res.handle), div_res.type) - return semantic.cast(cdiv_res, x_scalar_type, _builder) - - @core.builtin @math._check_dtype(dtypes=["bf16", "fp16", "fp32"]) @math._add_math_1arg_docstr("acos") @@ -993,338 +776,45 @@ def copysign(arg0: core.tensor, arg1: core.tensor, _builder: ir.builder): return semantic.where(is_negative, neg_magnitude, magnitude, _builder) -@core.builtin -def index_put(ptr: tensor, index: tensor, value: tensor, dim: int, index_boundary: int, end_offset: tuple, - start_offset: tuple, dst_stride: tuple, _builder=None): - """ - Index put values from a tensor into a destination tensor. - - Index put operation for different tensor ranks: - 1. 2D index scatter (0 <= dim < 1): - 1.1 dim = 0 - out[index[i]][start_offset[1]:end_offset[1]] = value[i][0:end_offset[1]-start_offset[1]] - 2. 
3D index scatter (0 <= dim < 2): - 2.1 dim = 0 - out[index[i]][start_offset[1]:end_offset[1]][start_offset[2]:end_offset[2]] - = value[i][0:end_offset[1]-start_offset[1]][0:end_offset[2]-start_offset[2]] - 2.2 dim = 1 - out[start_offset[0]:end_offset[0]][index[j]][start_offset[2]:end_offset[2]] - = value[0:end_offset[0]-start_offset[0]][j][0:end_offset[2]-start_offset[2]] - - - :param ptr: pointer type, the destination tensor pointer (in GM) - :param index: tensor, a index to scatter (in UB) - :param value: tensor, a value to store (in UB) - :param dim: int32, the dimension to scatter along - :param index_boundary: int64, the upper boundary for index values - :param end_offset: tuple of int, the offsets of each dimension for the end of the scatter region - :param start_offset: tuple of int, the offsets of each dimension for the start of the scatter region - :param dst_stride: tuple of int, the stride of each dimension of destination tensor - - Constraints - *********** - - `ptr` and `value` must have the same rank. - - `ptr.dtype` only supports `float16`, `bfloat16`, `float32` currently. - - `index` must be an integer tensor. If `index.rank` != 1, it will be reshaped to 1D. - - `index.numel` must equal `value.shape[dim]`. - - `value` support 2~5D tensors. - - `dim` must be valid (0 <= dim < rank(value) - 1). - - Example - ******* - .. code-block:: python - - import torch - import triton - import triton.language as tl - from triton.language.extra.ascend.libdevice import index_put - - @triton.jit - def simple_index_put_kernel(value_ptr, index_ptr, dst_ptr): - # index tile shape: [2] - index_local = tl.arange(0, 2) - x1_local = tl.arange(0, 2)[None, :] # shape=(1,2) - - index_tile = tl.load(index_ptr + index_local) - value_tile = tl.load(value_ptr + index_local[:, None]*2 + x1_local) - - index_put( - ptr=dst_ptr, - index=index_tile, - value=value_tile, - dim=0, - index_boundary=4, - end_offset=(2, 2), - start_offset=(0, 0), - dst_stride=(2, 1) - ) - - dst = torch.zeros((4,2), device='npu', dtype=torch.float32) - value = torch.tensor([[1.,2.], [3.,4.]], device='npu') - index = torch.tensor([2, 0], device='npu') - - simple_index_put_kernel[(1,)](value, index, dst) - print("IndexPut result:", dst) # ref:[[3.,4.], [0.,0.], [1.,2.], [0.,0.]] - """ - dim = _constexpr_to_value(dim) - index_boundary = _constexpr_to_value(index_boundary) - return semantic.index_put(ptr, index, value, dim, index_boundary, end_offset, start_offset, dst_stride, _builder) - - -@core.builtin -def gather_out_to_ub(src: tensor, index: tensor, index_boundary: int, dim: int, src_stride: tuple, end_offset: tuple, - start_offset: tuple, other=None, _builder=None): - """ - Gather from a source tensor in Global Memory (GM) to Unified Buffer (UB) - along a specified dimension with out-of-bound handling. - - Gather operation for different tensor ranks: - 1. 1D index gather: - out[i] = src[start_offset[0] + index[i]] - 2. 2D index gather (0 <= dim < 2): - 2.1 dim = 0 - out[i][j] = src[start_offset[0] + index[i][j]][start_offset[1] + j] - 2.2 dim = 1 - out[i][j] = src[start_offset[0] + i][start_offset[1] + index[i][j]] - 3. 
3D index gather (0 <= dim < 3): - 3.1 dim = 0 - out[i][j][k] = src[start_offset[0] + index[i][j][k]][start_offset[1] + j][start_offset[2] + k] - 3.2 dim = 1 - out[i][j][k] = src[start_offset[0] + i][start_offset[1] + index[i][j][k]][start_offset[2] + k] - 3.3 dim = 2 - out[i][j][k] = src[start_offset[0] + i][start_offset[1] + j][start_offset[2] + index[i][j][k]] - - :param src: pointer type, the source tensor pointer (in GM) - :param index: tensor, a tensor to gather (in UB) - :param index_boundary: int64, the upper boundary for index values - :param dim: int32, the dimension to gather along - :param src_stride: tuple of int64, the stride of each dimension of src tensor - :param end_offset: tuple of int32, the end offsets of each dimension for index tensor - :param start_offset: tuple of int32, the start offsets of each dimension for index tensor - :param other(Optional): scalar value, the default value when index is out of boundary (in UB) - :return: tensor, with the same shape as `index.shape` (in UB) - - Constraints - *********** - - `src` and `index` must have the same rank. - - `src.dtype` only supports `float16`, `bfloat16`, `float32` currently. - - `index` must be an integer tensor, with rank between 1 and 5. - - `dim` must be valid (0 <= dim < rank(index)). - - `other` must be a scalar value. - - For every dimension `i` not equal to `dim`, `index.size[i]` <= `src.size[i]`. - - The output shape is the same as `index.shape`. If `index` is None, \ - the output tensor will be an empty tensor with the same shape as `index`. - - Example - ******* - .. code-block:: python - - import torch - import triton - import triton.language as tl - from triton.language.extra.ascend.libdevice import gather_out_to_ub - - @triton.jit - def simple_gather_kernel(src_ptr, index_ptr, out_ptr): - # index tile shape: [2,2] - y0_local = tl.arange(0, 2)[:, None] # [0,1] rows - x1_local = tl.arange(0, 2)[None, :] # [0,1] cols - mask = (y0_local < 2) & (x1_local < 2) - - # Load index tile to UB - index = tl.load(index_ptr + y0_local*2 + x1_local, mask) - - # Call gather_out_to_ub: gather values from src along dim=0 - gathered = gather_out_to_ub( - src=src_ptr, - index=index, - index_boundary=4, - dim=0, - src_stride=(2, 1), - end_offset=(2, 2), - start_offset=(0, 0) - ) - - tl.store(out_ptr + y0_local*2 + x1_local, gathered, mask) - - src = torch.tensor([[1.,2.], [3.,4.], [5.,6.], [7.,8.]], device='npu') - index = torch.tensor([[0,1], [2,3]], device='npu') - out = torch.empty((2,2), device='npu', dtype=torch.float32) - - simple_gather_kernel[(1,)](src, index, out) - print("Gather result:", out) # ref: [[1.,4.], [5.,8.]] - """ - dim = _constexpr_to_value(dim) - index_boundary = _constexpr_to_value(index_boundary) - return semantic.gather_out_to_ub(src, index, index_boundary, dim, src_stride, end_offset, start_offset, other, +if get_ascend_arch_from_env() == "Ascend910_9589": + # if we have hardware support + @core.extern + def rint(arg0, _builder=None): + return core.extern_elementwise( + "", "", [arg0], { + (core.dtype("fp32"), ): ("__hmf_rint", core.dtype("fp32")), + (core.dtype("fp16"), ): ("__hmf_rint", core.dtype("fp16")), + (core.dtype("bf16"), ): ("__hmf_rint", core.dtype("bf16")), + }, is_pure=True, _builder=_builder) +else: + + @core.builtin + @math._check_dtype(dtypes=["fp16", "fp32", "bf16"]) + @math._add_math_1arg_docstr("rint") + def rint(arg0: core.tensor, _builder: ir.builder): + arg0 = semantic.to_tensor(arg0, _builder) + + floor_x = math.floor(arg0, _builder=_builder) + fractional = 
semantic.sub(arg0, floor_x, True, _builder) + + half = semantic.full(arg0.shape, 0.5, arg0.type.scalar, _builder) + eps = semantic.full(arg0.shape, 1e-8, arg0.type.scalar, _builder) + is_half = semantic.less_than(math.abs(semantic.sub(fractional, half, True, _builder), _builder=_builder), eps, _builder) + floor_int = floor_x.to(core.int32, _builder=_builder) if hasattr(floor_x, "to") else semantic.cast( + floor_x, core.int32, _builder) + two_i32 = semantic.full(arg0.shape, 2, core.int32, _builder) + is_even = semantic.equal(semantic.mod(floor_int, two_i32, _builder), + semantic.full(arg0.shape, 0, core.int32, _builder), _builder) -@core.builtin -def scatter_ub_to_out(ptr: tensor, value: tensor, index: tensor, index_boundary: int, dim: int, dst_stride: tuple, - end_offset: tuple, start_offset: tuple, _builder=None): - """ - Scatter a tile from Unified Buffer (UB) into a destination tensor in Global Memory (GM) - along a specified dimension, with index-boundary checking. - - Scatter operation for different tensor ranks: - 1. 1D index scatter: - out[start_offset[0] + index[i]] = value[i] - 2. 2D index scatter (0 <= dim < 2): - 2.1 dim = 0 - out[start_offset[0] + index[i][j]][start_offset[1] + j] = value[i][j] - 2.2 dim = 1 - out[start_offset[0] + i][start_offset[1] + index[i][j]] = value[i][j] - 3. 3D index scatter (0 <= dim < 3): - 3.1 dim = 0 - out[start_offset[0] + index[i][j][k]][start_offset[1] + j][start_offset[2] + k] = value[i][j][k] - 3.2 dim = 1 - out[start_offset[0] + i][start_offset[1] + index[i][j][k]][start_offset[2] + k] = value[i][j][k] - 3.3 dim = 2 - out[start_offset[0] + i][start_offset[1] + j][start_offset[2] + index[i][j][k]] = value[i][j][k] - - :param ptr: pointer type, the destination tensor pointer (in GM) - :param value: tensor, a tile value to store (in UB) - :param index: tensor, a index to scatter (in UB) - :param index_boundary: int64, the upper boundary for index values - :param dim: int32, the dimension to scatter along - :param dst_stride: tuple of int64, the stride of each dimension of destination tensor - :param end_offset: tuple of int32, the end offsets of each dimension for index tensor - :param start_offset: tuple of int32, the start offsets of each dimension for index tensor - - Constraints - *********** - - `ptr`, `index` and `value` must have the same rank. - - `ptr.dtype` only supports `float16`, `bfloat16`, `float32` currently. - - `index` must be an integer tensor, with rank between 1 and 5. - - `dim` must be valid (0 <= dim < rank(index)). - - For every dimension `i` not equal to `dim`, `index.size[i]` <= `ptr.size[i]`. - - The output shape is the same as `index.shape`. If `index` is None, \ - the output tensor will be an empty tensor with the same shape as `index`. - - Example - ******* - .. 
code-block:: python - - import torch - import triton - import triton.language as tl - from triton.language.extra.ascend.libdevice import scatter_ub_to_out - - @triton.jit - def simple_scatter_kernel(value_ptr, index_ptr, dst_ptr): - # index tile shape: [2,2] - y0_local = tl.arange(0, 2)[:, None] # [0,1] rows - x1_local = tl.arange(0, 2)[None, :] # [0,1] cols - mask = (y0_local < 2) & (x1_local < 2) - - value = tl.load(value_ptr + y0_local*2 + x1_local, mask) - index = tl.load(index_ptr + y0_local*2 + x1_local, mask) - - scatter_ub_to_out( - ptr=dst_ptr, - value=value, - index=index, - index_boundary=4, - dim=0, - dst_stride=(2, 1), - end_offset=(2, 2), - start_offset=(0, 0) - ) - - dst = torch.zeros((4,2), device='npu', dtype=torch.float32) - value = torch.tensor([[1.,2.], [3.,4.]], device='npu') - index = torch.tensor([[1,2], [3,0]], device='npu') - - simple_scatter_kernel[(1,)](value, index, dst) - print("Scatter result:", dst) # ref:[[0.,4.], [1.,0.], [0.,2.], [3.,0.]] - """ - dim = _constexpr_to_value(dim) - index_boundary = _constexpr_to_value(index_boundary) - return semantic.scatter_ub_to_out(ptr, value, index, index_boundary, dim, dst_stride, end_offset, start_offset, - _builder) - - -@core.builtin -def index_select_simd(src, dim, index, src_shape, src_offset, read_shape, _builder=None) -> tensor: - """ - Parallel index_select operation from Global Memory to Unified Buffer (SIMD version). - - Selects data from multiple indices along a specified dimension and loads - them as tiles from GM directly to UB with zero-copy semantics. - - :param src: Source tensor pointer (in GM) - :type src: tensor (pointer type) - :param dim: The dimension along which to select indices - :type dim: int or constexpr - :param index: 1D tensor of indices to select (in UB) - :type index: tensor - :param src_shape: Complete shape of the source tensor (can be int or tensor) - :type src_shape: List[Union[int, tensor]] - :param src_offset: Starting offset for reading (can be int or tensor) - :type src_offset: List[Union[int, tensor]] - :param read_shape: Size to read (tile shape, can be int or tensor) - :type read_shape: List[Union[int, tensor]] - - **Constraints:** - - - ``read_shape[dim]`` must be ``-1`` - - ``src_offset[dim]`` can be ``-1`` (will be ignored) - - Boundary handling: ``src_offset + read_shape > src_shape`` automatically - truncates to ``src_shape`` boundary - - Does not check if ``index`` contains out-of-bounds values - - **Example:** - - .. 
code-block:: python - - @triton.jit - def kernel(src_ptr, output_ptr, indices_ptr, M, N, D, ...): - # Load indices (e.g., [5, 10, 15, 20]) - indices = tl.load(indices_ptr + tl.arange(0, 4)) - - # Example 1: Static shapes (constants) - # Index select from dimension 1 - # src: [8, 100, 256], index_select at dim=1 - # Read: [4, ?, 128] starting from [4, ?, 128] - result = libdevice.index_select_simd( - src_ptr, - dim=1, - index=indices, - src_shape=[8, 100, 256], - src_offset=[4, -1, 128], - read_shape=[4, -1, 128] - ) - # result shape: [4, 4, 128] - - # Example 2: Dynamic shapes (variables) - result2 = libdevice.index_select_simd( - src_ptr, - dim=1, - index=indices, - src_shape=[M, N, D], - src_offset=[4, -1, 128], - read_shape=[4, -1, 128] - ) - - tl.store(output_ptr + ..., result) - - :return: Result tensor in UB with shape where ``dim`` is replaced - by the length of ``index`` - :rtype: tensor - """ - dim = _constexpr_to_value(dim) + zero = semantic.full(arg0.shape, 0.0, arg0.type.scalar, _builder) + is_pos = semantic.greater_equal(arg0, zero, _builder) - # Process shape parameters: convert constexpr to values, keep tensors as-is - def process_param(val): - """Convert constexpr to value, keep tensor or int as-is""" - if isinstance(val, tensor): - return val - else: - return _constexpr_to_value(val) + round_pos = math.floor(semantic.add(arg0, half, True, _builder), _builder=_builder) + round_neg = math.ceil(semantic.sub(arg0, half, True, _builder), _builder=_builder) + normal_round = semantic.where(is_pos, round_pos, round_neg, _builder) - newsrc_shape = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in src_shape] - newsrc_offset = [semantic.to_tensor(o, _builder) if isinstance(o, constexpr) else o for o in src_offset] - assert len(index.shape) == 1, "index must be a 1D tensor" + half_round = semantic.where(is_even, floor_x, semantic.add(floor_x, 1.0, True, _builder), _builder) - return semantic.index_select_simd(src, dim, index, newsrc_shape, newsrc_offset, read_shape, _builder) + return semantic.where(is_half, half_round, normal_round, _builder) diff --git a/third_party/ascend/language/kernels/__init__.py b/third_party/ascend/language/kernels/__init__.py new file mode 100644 index 000000000..bd828caca --- /dev/null +++ b/third_party/ascend/language/kernels/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Optimized Triton kernels for Ascend NPU. 
+ +This module provides high-level, optimized kernels written in Triton for common operations. +These are different from the low-level builtin extensions in `cann.extension`. +""" +__all__ = ["gather_2d_simd"] + +from .gather import gather_2d_simd diff --git a/third_party/ascend/language/kernels/gather.py b/third_party/ascend/language/kernels/gather.py new file mode 100644 index 000000000..56d1ce016 --- /dev/null +++ b/third_party/ascend/language/kernels/gather.py @@ -0,0 +1,88 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +Gather kernel optimized for Ascend NPU. +""" +__all__ = ["gather_2d_simd"] + +import triton +import triton.language as tl +from triton.language.core import constexpr + + +@triton.jit +def gather_2d_simd(src_ptr, idx_ptr, out_ptr, M: constexpr, N: constexpr, K: constexpr, XBLOCK: constexpr, + XBLOCK_SUB: constexpr): + """ + 2D gather kernel for axis=1 (tail axis) with SIMD-style vectorization. 
+ + This kernel is optimized for Ascend NPU architecture with focus on: + - Vectorized memory access using XBLOCK_SUB + - Efficient global memory (GM) access patterns + - Suitable for cases where N and K are not extremely large + + Args: + src_ptr: [M, N] source tensor in GM (Global Memory) + idx_ptr: [M, K] indices tensor in GM + out_ptr: [M, K] output tensor in GM + M: batch dimension size + N: source dimension size (gather from this dimension) + K: output dimension size (number of indices per batch) + XBLOCK: outer block size for M dimension (for program distribution) + XBLOCK_SUB: inner block size for M dimension (for SIMD vectorization) + + Example: + import torch + import triton + from third_party.ascend.language.kernels import gather_2d_simd + + M, N, K = 128, 256, 64 + src = torch.randn(M, N, device='npu') + indices = torch.randint(0, N, (M, K), dtype=torch.int32, device='npu') + output = torch.empty((M, K), dtype=src.dtype, device='npu') + + grid = (triton.cdiv(M, 32),) + gather_2d_simd[grid](src, indices, output, M, N, K, + XBLOCK=32, XBLOCK_SUB=4) + """ + pid = tl.program_id(0) + m_start = pid * XBLOCK + m_end = min(m_start + XBLOCK, M) + m_base = tl.arange(0, XBLOCK_SUB) + + # Process multiple rows at once using XBLOCK_SUB for vectorization + for m_tile_start in range(m_start, m_end, XBLOCK_SUB): + # M dimension offsets: [XBLOCK_SUB] + m_offs = m_tile_start + m_base + m_mask = m_offs < M + + # Load indices: [XBLOCK_SUB, K] + k_offs = tl.arange(0, K) + idx_tile = tl.load(idx_ptr + m_offs[:, None] * K + k_offs[None, :]) + + # Load source data: [XBLOCK_SUB, N] + n_offs = tl.arange(0, N) + src_tile = tl.load(src_ptr + m_offs[:, None] * N + n_offs[None, :]) + + # Gather operation along axis=1 + gathered_values = tl.gather(src_tile, idx_tile, axis=1) + + # Store results + tl.store(out_ptr + m_offs[:, None] * K + k_offs[None, :], gathered_values, mask=m_mask[:, None]) diff --git a/third_party/ascend/python/src/ir.cc b/third_party/ascend/python/src/ir.cc index b2e743fec..7a5e99620 100644 --- a/third_party/ascend/python/src/ir.cc +++ b/third_party/ascend/python/src/ir.cc @@ -733,12 +733,15 @@ void init_triton_ir(py::module &&m) { .def("get_fp8e4nv", [](TritonOpBuilder &self, double v) -> Value { return self.create( - FloatAttr::get(self.getBuilder().getFloat8E4M3FNType(), v)); + // FLAGTREE + // FloatAttr::get(self.getBuilder().getFloat8E4M3FNType(), v)); + FloatAttr::get(self.getBuilder().getType(), v)); }) .def("get_fp8e5", [](TritonOpBuilder &self, double v) -> Value { return self.create( - FloatAttr::get(self.getBuilder().getFloat8E5M2Type(), v)); + // FloatAttr::get(self.getBuilder().getFloat8E5M2Type(), v)); + FloatAttr::get(self.getBuilder().getType(), v)); }) .def("get_null_value", [](TritonOpBuilder &self, Type type) -> Value { diff --git a/third_party/ascend/python/src/ir.h b/third_party/ascend/python/src/ir.h new file mode 100644 index 000000000..47e242c1f --- /dev/null +++ b/third_party/ascend/python/src/ir.h @@ -0,0 +1,105 @@ +// FIXME: +#pragma once +#include "mlir/IR/Builders.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include +#include +#include +#include + +namespace py = pybind11; + +using namespace mlir; +using namespace triton; + +// A custom op builder that keeps track of the last location +class TritonOpBuilder { +public: + TritonOpBuilder(mlir::MLIRContext *context, + const std::string &compile_mode = "simd") { + builder = std::make_unique(context); + lastLoc = std::make_unique(builder->getUnknownLoc()); + this->compile_mode = compile_mode; + } + + OpBuilder 
&getBuilder() { return *builder; } + + bool isLineInfoEnabled() { return lineInfoEnabled; } + + bool isSimtMode() const { return compile_mode == "simt"; } + + void setLastLoc(Location loc) { + if (lineInfoEnabled) + lastLoc = std::make_unique(loc); + } + + void setLastLoc(const std::string &fileName, int line, int column) { + auto context = builder->getContext(); + setLastLoc(FileLineColLoc::get(context, fileName, line, column)); + } + + Location getLastLoc() { + assert(lastLoc); + return *lastLoc; + } + + void setInsertionPointToStart(Block &block) { + if (!block.empty()) + setLastLoc(block.begin()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToStart(&block); + } + + void setInsertionPointToEnd(Block &block) { + if (!block.empty()) + setLastLoc(block.back().getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->setInsertionPointToEnd(&block); + } + + void setInsertionPointAfter(Operation &op) { + setLastLoc(op.getLoc()); + builder->setInsertionPointAfter(&op); + } + + void restoreInsertionPoint(OpBuilder::InsertPoint pt) { + if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) + setLastLoc(pt.getPoint()->getLoc()); + else + setLastLoc(builder->getUnknownLoc()); + builder->restoreInsertionPoint(pt); + } + + template OpTy create(Args &&...args) { + auto loc = getLastLoc(); + return builder->create(loc, std::forward(args)...); + } + + // Overload to create or fold a single result operation. + template + std::enable_if_t(), Value> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + + // Overload to create or fold a zero result operation. + template + std::enable_if_t(), OpTy> + createOrFold(Args &&...args) { + auto loc = getLastLoc(); + return builder->createOrFold(loc, std::forward(args)...); + } + +private: + std::unique_ptr builder; + std::unique_ptr lastLoc; + bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); + std::string compile_mode; +}; + +namespace ir { +extern py::class_ *getBuilderClass(); +} // namespace ir diff --git a/third_party/ascend/test/Conversion/TritonOp/argmax_uint.mlir b/third_party/ascend/test/Conversion/TritonOp/argmax_uint.mlir index 8a9d1358c..bfb3257b1 100644 --- a/third_party/ascend/test/Conversion/TritonOp/argmax_uint.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/argmax_uint.mlir @@ -1,89 +1,89 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-a5=False force_simt_template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-a5=False' --split-input-file %s | FileCheck %s - - -module { - tt.func public @triton_argmax_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmax.py":21:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmax.py":21:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmax.py":21:0)) attributes {noinline = false} { - %0 = tt.get_program_id x : i32 loc(#loc1) - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) - %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) - %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) - %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) - %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> 
loc(#loc4) - %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) - %7:2 = "tt.reduce"(%6, %1) <{axis = 0 : i32}> ({ - ^bb0(%arg3: i8 loc(callsite(#loc21 at #loc8)), %arg4: i32 loc(callsite(#loc21 at #loc8)), %arg5: i8 loc(callsite(#loc21 at #loc8)), %arg6: i32 loc(callsite(#loc21 at #loc8))): - %8 = arith.cmpi eq, %arg3, %arg5 : i8 loc(#loc45) - %9 = arith.cmpi slt, %arg4, %arg6 : i32 loc(#loc46) - %10 = arith.andi %8, %9 : i1 loc(#loc47) - %11 = arith.cmpi ugt, %arg3, %arg5 : i8 loc(#loc48) - %12 = arith.ori %11, %10 : i1 loc(#loc49) - %13 = arith.select %12, %arg3, %arg5 : i8 loc(#loc50) - %14 = arith.select %12, %arg4, %arg6 : i32 loc(#loc51) - tt.reduce.return %13, %14 : i8, i32 loc(#loc29) - }) : (tensor<16xi8>, tensor<16xi32>) -> (i8, i32) loc(#loc29) - tt.store %arg1, %7#1 : !tt.ptr loc(#loc18) - tt.return loc(#loc19) - } loc(#loc) -} loc(#loc) - - -// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.cmpi ugt, %[[VAL_0]], %[[VAL_1]] : i8 -// ----- - - -module { - tt.func public @triton_argmax_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u16_argmax.py":20:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u16_argmax.py":20:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u16_argmax.py":20:0)) attributes {noinline = false} { - %0 = tt.get_program_id x : i32 loc(#loc1) - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) - %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) - %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) - %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) - %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) - %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) - %7:2 = "tt.reduce"(%6, %1) <{axis = 0 : i32}> ({ - ^bb0(%arg3: i16 loc(callsite(#loc21 at #loc8)), %arg4: i32 loc(callsite(#loc21 at #loc8)), %arg5: i16 loc(callsite(#loc21 at #loc8)), %arg6: i32 loc(callsite(#loc21 at #loc8))): - %8 = arith.cmpi eq, %arg3, %arg5 : i16 loc(#loc45) - %9 = arith.cmpi slt, %arg4, %arg6 : i32 loc(#loc46) - %10 = arith.andi %8, %9 : i1 loc(#loc47) - %11 = arith.cmpi ugt, %arg3, %arg5 : i16 loc(#loc48) - %12 = arith.ori %11, %10 : i1 loc(#loc49) - %13 = arith.select %12, %arg3, %arg5 : i16 loc(#loc50) - %14 = arith.select %12, %arg4, %arg6 : i32 loc(#loc51) - tt.reduce.return %13, %14 : i16, i32 loc(#loc29) - }) : (tensor<16xi16>, tensor<16xi32>) -> (i16, i32) loc(#loc29) - tt.store %arg1, %7#1 : !tt.ptr loc(#loc18) - tt.return loc(#loc19) - } loc(#loc) -} loc(#loc) - - -// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.cmpi ugt, %[[VAL_0]], %[[VAL_1]] : i16 -// ----- - -module { - tt.func public @triton_argmax_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u32_argmax.py":21:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u32_argmax.py":21:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u32_argmax.py":21:0)) attributes {noinline = false} { - %0 = tt.get_program_id x : i32 loc(#loc1) - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) - %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) - %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) - %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) - %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) - %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) - %7:2 = "tt.reduce"(%6, %1) <{axis = 0 : i32}> ({ - ^bb0(%arg3: i32 loc(callsite(#loc21 at #loc8)), %arg4: i32 
loc(callsite(#loc21 at #loc8)), %arg5: i32 loc(callsite(#loc21 at #loc8)), %arg6: i32 loc(callsite(#loc21 at #loc8))): - %8 = arith.cmpi eq, %arg3, %arg5 : i32 loc(#loc45) - %9 = arith.cmpi slt, %arg4, %arg6 : i32 loc(#loc46) - %10 = arith.andi %8, %9 : i1 loc(#loc47) - %11 = arith.cmpi ugt, %arg3, %arg5 : i32 loc(#loc48) - %12 = arith.ori %11, %10 : i1 loc(#loc49) - %13 = arith.select %12, %arg3, %arg5 : i32 loc(#loc50) - %14 = arith.select %12, %arg4, %arg6 : i32 loc(#loc51) - tt.reduce.return %13, %14 : i32, i32 loc(#loc29) - }) : (tensor<16xi32>, tensor<16xi32>) -> (i32, i32) loc(#loc29) - tt.store %arg1, %7#1 : !tt.ptr loc(#loc18) - tt.return loc(#loc19) - } loc(#loc) -} loc(#loc) - -// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.cmpi ugt, %[[VAL_0]], %[[VAL_1]] : i32 +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-a5=False force_simt_template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-a5=False' --split-input-file %s | FileCheck %s + + +module { + tt.func public @triton_argmax_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmax.py":21:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmax.py":21:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmax.py":21:0)) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 loc(#loc1) + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) + %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) + %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) + %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) + %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) + %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) + %7:2 = "tt.reduce"(%6, %1) <{axis = 0 : i32}> ({ + ^bb0(%arg3: i8 loc(callsite(#loc21 at #loc8)), %arg4: i32 loc(callsite(#loc21 at #loc8)), %arg5: i8 loc(callsite(#loc21 at #loc8)), %arg6: i32 loc(callsite(#loc21 at #loc8))): + %8 = arith.cmpi eq, %arg3, %arg5 : i8 loc(#loc45) + %9 = arith.cmpi slt, %arg4, %arg6 : i32 loc(#loc46) + %10 = arith.andi %8, %9 : i1 loc(#loc47) + %11 = arith.cmpi ugt, %arg3, %arg5 : i8 loc(#loc48) + %12 = arith.ori %11, %10 : i1 loc(#loc49) + %13 = arith.select %12, %arg3, %arg5 : i8 loc(#loc50) + %14 = arith.select %12, %arg4, %arg6 : i32 loc(#loc51) + tt.reduce.return %13, %14 : i8, i32 loc(#loc29) + }) : (tensor<16xi8>, tensor<16xi32>) -> (i8, i32) loc(#loc29) + tt.store %arg1, %7#1 : !tt.ptr loc(#loc18) + tt.return loc(#loc19) + } loc(#loc) +} loc(#loc) + + +// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.cmpi ugt, %[[VAL_0]], %[[VAL_1]] : i8 +// ----- + + +module { + tt.func public @triton_argmax_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u16_argmax.py":20:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u16_argmax.py":20:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u16_argmax.py":20:0)) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 loc(#loc1) + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) + %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) + %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) + %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) + %5 = tt.addptr %4, %3 : 
tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) + %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) + %7:2 = "tt.reduce"(%6, %1) <{axis = 0 : i32}> ({ + ^bb0(%arg3: i16 loc(callsite(#loc21 at #loc8)), %arg4: i32 loc(callsite(#loc21 at #loc8)), %arg5: i16 loc(callsite(#loc21 at #loc8)), %arg6: i32 loc(callsite(#loc21 at #loc8))): + %8 = arith.cmpi eq, %arg3, %arg5 : i16 loc(#loc45) + %9 = arith.cmpi slt, %arg4, %arg6 : i32 loc(#loc46) + %10 = arith.andi %8, %9 : i1 loc(#loc47) + %11 = arith.cmpi ugt, %arg3, %arg5 : i16 loc(#loc48) + %12 = arith.ori %11, %10 : i1 loc(#loc49) + %13 = arith.select %12, %arg3, %arg5 : i16 loc(#loc50) + %14 = arith.select %12, %arg4, %arg6 : i32 loc(#loc51) + tt.reduce.return %13, %14 : i16, i32 loc(#loc29) + }) : (tensor<16xi16>, tensor<16xi32>) -> (i16, i32) loc(#loc29) + tt.store %arg1, %7#1 : !tt.ptr loc(#loc18) + tt.return loc(#loc19) + } loc(#loc) +} loc(#loc) + + +// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.cmpi ugt, %[[VAL_0]], %[[VAL_1]] : i16 +// ----- + +module { + tt.func public @triton_argmax_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u32_argmax.py":21:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u32_argmax.py":21:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u32_argmax.py":21:0)) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 loc(#loc1) + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) + %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) + %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) + %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) + %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) + %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) + %7:2 = "tt.reduce"(%6, %1) <{axis = 0 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc21 at #loc8)), %arg4: i32 loc(callsite(#loc21 at #loc8)), %arg5: i32 loc(callsite(#loc21 at #loc8)), %arg6: i32 loc(callsite(#loc21 at #loc8))): + %8 = arith.cmpi eq, %arg3, %arg5 : i32 loc(#loc45) + %9 = arith.cmpi slt, %arg4, %arg6 : i32 loc(#loc46) + %10 = arith.andi %8, %9 : i1 loc(#loc47) + %11 = arith.cmpi ugt, %arg3, %arg5 : i32 loc(#loc48) + %12 = arith.ori %11, %10 : i1 loc(#loc49) + %13 = arith.select %12, %arg3, %arg5 : i32 loc(#loc50) + %14 = arith.select %12, %arg4, %arg6 : i32 loc(#loc51) + tt.reduce.return %13, %14 : i32, i32 loc(#loc29) + }) : (tensor<16xi32>, tensor<16xi32>) -> (i32, i32) loc(#loc29) + tt.store %arg1, %7#1 : !tt.ptr loc(#loc18) + tt.return loc(#loc19) + } loc(#loc) +} loc(#loc) + +// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.cmpi ugt, %[[VAL_0]], %[[VAL_1]] : i32 // ----- diff --git a/third_party/ascend/test/Conversion/TritonOp/argmin_uint.mlir b/third_party/ascend/test/Conversion/TritonOp/argmin_uint.mlir index 5b24421f2..921ce9f68 100644 --- a/third_party/ascend/test/Conversion/TritonOp/argmin_uint.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/argmin_uint.mlir @@ -1,89 +1,89 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-a5=False force_simt_template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-a5=False' --split-input-file %s | FileCheck %s - -module { - tt.func public @triton_argmin_1d8(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":22:0), 
%arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":22:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":22:0)) attributes {noinline = false} { - %0 = tt.get_program_id x : i32 loc(#loc1) - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) - %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) - %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) - %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) - %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) - %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) - %7:2 = "tt.reduce"(%6, %1) <{axis = 0 : i32}> ({ - ^bb0(%arg3: i8 loc(callsite(#loc21 at #loc8)), %arg4: i32 loc(callsite(#loc21 at #loc8)), %arg5: i8 loc(callsite(#loc21 at #loc8)), %arg6: i32 loc(callsite(#loc21 at #loc8))): - %8 = arith.cmpi eq, %arg3, %arg5 : i8 loc(#loc45) - %9 = arith.cmpi slt, %arg4, %arg6 : i32 loc(#loc46) - %10 = arith.andi %8, %9 : i1 loc(#loc47) - %11 = arith.cmpi ult, %arg3, %arg5 : i8 loc(#loc48) - %12 = arith.ori %11, %10 : i1 loc(#loc49) - %13 = arith.select %12, %arg3, %arg5 : i8 loc(#loc50) - %14 = arith.select %12, %arg4, %arg6 : i32 loc(#loc51) - tt.reduce.return %13, %14 : i8, i32 loc(#loc29) - }) : (tensor<16xi8>, tensor<16xi32>) -> (i8, i32) loc(#loc29) - tt.store %arg1, %7#1 : !tt.ptr loc(#loc18) - tt.return loc(#loc19) - } loc(#loc) -} loc(#loc) - - -// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.cmpi ult, %[[VAL_0]], %[[VAL_1]] : i8 -// --------- - -module { - tt.func public @triton_argmin_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":45:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":45:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":45:0)) attributes {noinline = false} { - %0 = tt.get_program_id x : i32 loc(#loc1) - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) - %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) - %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) - %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) - %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) - %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) - %7:2 = "tt.reduce"(%6, %1) <{axis = 0 : i32}> ({ - ^bb0(%arg3: i16 loc(callsite(#loc21 at #loc8)), %arg4: i32 loc(callsite(#loc21 at #loc8)), %arg5: i16 loc(callsite(#loc21 at #loc8)), %arg6: i32 loc(callsite(#loc21 at #loc8))): - %8 = arith.cmpi eq, %arg3, %arg5 : i16 loc(#loc45) - %9 = arith.cmpi slt, %arg4, %arg6 : i32 loc(#loc46) - %10 = arith.andi %8, %9 : i1 loc(#loc47) - %11 = arith.cmpi ult, %arg3, %arg5 : i16 loc(#loc48) - %12 = arith.ori %11, %10 : i1 loc(#loc49) - %13 = arith.select %12, %arg3, %arg5 : i16 loc(#loc50) - %14 = arith.select %12, %arg4, %arg6 : i32 loc(#loc51) - tt.reduce.return %13, %14 : i16, i32 loc(#loc29) - }) : (tensor<16xi16>, tensor<16xi32>) -> (i16, i32) loc(#loc29) - tt.store %arg1, %7#1 : !tt.ptr loc(#loc18) - tt.return loc(#loc19) - } loc(#loc) -} loc(#loc) - - -// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.cmpi ult, %[[VAL_0]], %[[VAL_1]] : i16 -// --------- - - -module { - tt.func public @triton_argmin_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":45:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":45:0), %arg2: i32 {tt.divisibility = 16 : i32} 
loc("/home/l30058175/wxue/test_u8_argmin.py":45:0)) attributes {noinline = false} { - %0 = tt.get_program_id x : i32 loc(#loc1) - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) - %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) - %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) - %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) - %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) - %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) - %7:2 = "tt.reduce"(%6, %1) <{axis = 0 : i32}> ({ - ^bb0(%arg3: i32 loc(callsite(#loc21 at #loc8)), %arg4: i32 loc(callsite(#loc21 at #loc8)), %arg5: i32 loc(callsite(#loc21 at #loc8)), %arg6: i32 loc(callsite(#loc21 at #loc8))): - %8 = arith.cmpi eq, %arg3, %arg5 : i32 loc(#loc45) - %9 = arith.cmpi slt, %arg4, %arg6 : i32 loc(#loc46) - %10 = arith.andi %8, %9 : i1 loc(#loc47) - %11 = arith.cmpi ult, %arg3, %arg5 : i32 loc(#loc48) - %12 = arith.ori %11, %10 : i1 loc(#loc49) - %13 = arith.select %12, %arg3, %arg5 : i32 loc(#loc50) - %14 = arith.select %12, %arg4, %arg6 : i32 loc(#loc51) - tt.reduce.return %13, %14 : i32, i32 loc(#loc29) - }) : (tensor<16xi32>, tensor<16xi32>) -> (i32, i32) loc(#loc29) - tt.store %arg1, %7#1 : !tt.ptr loc(#loc18) - tt.return loc(#loc19) - } loc(#loc) -} loc(#loc) - - -// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.cmpi ult, %[[VAL_0]], %[[VAL_1]] : i32 +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-a5=False force_simt_template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-a5=False' --split-input-file %s | FileCheck %s + +module { + tt.func public @triton_argmin_1d8(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":22:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":22:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":22:0)) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 loc(#loc1) + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) + %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) + %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) + %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) + %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) + %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) + %7:2 = "tt.reduce"(%6, %1) <{axis = 0 : i32}> ({ + ^bb0(%arg3: i8 loc(callsite(#loc21 at #loc8)), %arg4: i32 loc(callsite(#loc21 at #loc8)), %arg5: i8 loc(callsite(#loc21 at #loc8)), %arg6: i32 loc(callsite(#loc21 at #loc8))): + %8 = arith.cmpi eq, %arg3, %arg5 : i8 loc(#loc45) + %9 = arith.cmpi slt, %arg4, %arg6 : i32 loc(#loc46) + %10 = arith.andi %8, %9 : i1 loc(#loc47) + %11 = arith.cmpi ult, %arg3, %arg5 : i8 loc(#loc48) + %12 = arith.ori %11, %10 : i1 loc(#loc49) + %13 = arith.select %12, %arg3, %arg5 : i8 loc(#loc50) + %14 = arith.select %12, %arg4, %arg6 : i32 loc(#loc51) + tt.reduce.return %13, %14 : i8, i32 loc(#loc29) + }) : (tensor<16xi8>, tensor<16xi32>) -> (i8, i32) loc(#loc29) + tt.store %arg1, %7#1 : !tt.ptr loc(#loc18) + tt.return loc(#loc19) + } loc(#loc) +} loc(#loc) + + +// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.cmpi ult, %[[VAL_0]], %[[VAL_1]] : i8 +// --------- + +module { + tt.func public @triton_argmin_1d(%arg0: !tt.ptr 
{tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":45:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":45:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":45:0)) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 loc(#loc1) + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) + %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) + %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) + %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) + %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) + %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) + %7:2 = "tt.reduce"(%6, %1) <{axis = 0 : i32}> ({ + ^bb0(%arg3: i16 loc(callsite(#loc21 at #loc8)), %arg4: i32 loc(callsite(#loc21 at #loc8)), %arg5: i16 loc(callsite(#loc21 at #loc8)), %arg6: i32 loc(callsite(#loc21 at #loc8))): + %8 = arith.cmpi eq, %arg3, %arg5 : i16 loc(#loc45) + %9 = arith.cmpi slt, %arg4, %arg6 : i32 loc(#loc46) + %10 = arith.andi %8, %9 : i1 loc(#loc47) + %11 = arith.cmpi ult, %arg3, %arg5 : i16 loc(#loc48) + %12 = arith.ori %11, %10 : i1 loc(#loc49) + %13 = arith.select %12, %arg3, %arg5 : i16 loc(#loc50) + %14 = arith.select %12, %arg4, %arg6 : i32 loc(#loc51) + tt.reduce.return %13, %14 : i16, i32 loc(#loc29) + }) : (tensor<16xi16>, tensor<16xi32>) -> (i16, i32) loc(#loc29) + tt.store %arg1, %7#1 : !tt.ptr loc(#loc18) + tt.return loc(#loc19) + } loc(#loc) +} loc(#loc) + + +// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.cmpi ult, %[[VAL_0]], %[[VAL_1]] : i16 +// --------- + + +module { + tt.func public @triton_argmin_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":45:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":45:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/test_u8_argmin.py":45:0)) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 loc(#loc1) + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) + %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) + %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) + %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) + %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) + %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) + %7:2 = "tt.reduce"(%6, %1) <{axis = 0 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc21 at #loc8)), %arg4: i32 loc(callsite(#loc21 at #loc8)), %arg5: i32 loc(callsite(#loc21 at #loc8)), %arg6: i32 loc(callsite(#loc21 at #loc8))): + %8 = arith.cmpi eq, %arg3, %arg5 : i32 loc(#loc45) + %9 = arith.cmpi slt, %arg4, %arg6 : i32 loc(#loc46) + %10 = arith.andi %8, %9 : i1 loc(#loc47) + %11 = arith.cmpi ult, %arg3, %arg5 : i32 loc(#loc48) + %12 = arith.ori %11, %10 : i1 loc(#loc49) + %13 = arith.select %12, %arg3, %arg5 : i32 loc(#loc50) + %14 = arith.select %12, %arg4, %arg6 : i32 loc(#loc51) + tt.reduce.return %13, %14 : i32, i32 loc(#loc29) + }) : (tensor<16xi32>, tensor<16xi32>) -> (i32, i32) loc(#loc29) + tt.store %arg1, %7#1 : !tt.ptr loc(#loc18) + tt.return loc(#loc19) + } loc(#loc) +} loc(#loc) + + +// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.cmpi ult, %[[VAL_0]], %[[VAL_1]] : i32 // --------- diff --git a/third_party/ascend/test/Conversion/TritonOp/associative_scan.mlir b/third_party/ascend/test/Conversion/TritonOp/associative_scan.mlir index 14834cfcb..01327a19a 100644 --- 
a/third_party/ascend/test/Conversion/TritonOp/associative_scan.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/associative_scan.mlir @@ -1,180 +1,180 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s - -// === i8 u8 version === -module { - tt.func public @fn_npu_u8( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: i8, %arg3: i8): - %7 = arith.maxui %arg2, %arg3 : i8 - tt.scan.return %7 : i8 - }) : (tensor<8xi8>) -> tensor<8xi8> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- -// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi8> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi8{{.*}}> to memref<8xi8> - -// CHECK: %[[OUTPUT_BUF:.*]] = memref.alloc() : memref<8xi8> - -// Initialize first element -// CHECK: %[[FIRST_VAL:.*]] = memref.load %[[INPUT_BUF]][%c0] : memref<8xi8> -// CHECK: memref.store %[[FIRST_VAL]], %[[OUTPUT_BUF]][%c0] : memref<8xi8> - -// Main scan loop -// CHECK: scf.for %{{.*}} = %c1 to %c8 step %c1 { -// CHECK-NEXT: %[[PREV_IDX:.*]] = arith.subi %{{.*}}, %c1 : index -// CHECK-NEXT: %[[CURR_INPUT:.*]] = memref.load %[[INPUT_BUF]][%{{.*}}] : memref<8xi8> -// CHECK-NEXT: %[[PREV_OUTPUT:.*]] = memref.load %[[OUTPUT_BUF]][%[[PREV_IDX]]] : memref<8xi8> -// CHECK-NEXT: %[[COMBINED:.*]] = arith.maxui %[[PREV_OUTPUT]], %[[CURR_INPUT]] : i8 -// CHECK-NEXT: memref.store %[[COMBINED]], %[[OUTPUT_BUF]][%{{.*}}] : memref<8xi8> -// CHECK-NEXT: } - -// Final materialization -// CHECK: bufferization.materialize_in_destination - - -// === i16 u16 version === -module { - tt.func public @fn_npu_u16( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: i16, %arg3: i16): - %7 = arith.maxui %arg2, %arg3 : i16 - tt.scan.return %7 : i16 - }) : (tensor<8xi16>) -> tensor<8xi16> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- -// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi16> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi16{{.*}}> to memref<8xi16> - -// CHECK: %[[OUTPUT_BUF:.*]] = memref.alloc() : memref<8xi16> - -// Initialize first element -// CHECK: %[[FIRST_VAL:.*]] = memref.load %[[INPUT_BUF]][%c0] : memref<8xi16> -// CHECK: memref.store %[[FIRST_VAL]], %[[OUTPUT_BUF]][%c0] : memref<8xi16> - -// Main scan loop -// CHECK: scf.for %{{.*}} = %c1 to %c8 step %c1 { -// CHECK-NEXT: %[[PREV_IDX:.*]] = arith.subi %{{.*}}, %c1 : index -// CHECK-NEXT: 
%[[CURR_INPUT:.*]] = memref.load %[[INPUT_BUF]][%{{.*}}] : memref<8xi16> -// CHECK-NEXT: %[[PREV_OUTPUT:.*]] = memref.load %[[OUTPUT_BUF]][%[[PREV_IDX]]] : memref<8xi16> -// CHECK-NEXT: %[[COMBINED:.*]] = arith.maxui %[[PREV_OUTPUT]], %[[CURR_INPUT]] : i16 -// CHECK-NEXT: memref.store %[[COMBINED]], %[[OUTPUT_BUF]][%{{.*}}] : memref<8xi16> -// CHECK-NEXT: } - -// Final materialization -// CHECK: bufferization.materialize_in_destination - - -// === i32 u32 version === -module { - tt.func public @fn_npu_u32( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: i32, %arg3: i32): - %7 = arith.maxui %arg2, %arg3 : i32 - tt.scan.return %7 : i32 - }) : (tensor<8xi32>) -> tensor<8xi32> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- -// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi32> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi32{{.*}}> to memref<8xi32> - -// CHECK: %[[OUTPUT_BUF:.*]] = memref.alloc() : memref<8xi32> - -// Initialize first element -// CHECK: %[[FIRST_VAL:.*]] = memref.load %[[INPUT_BUF]][%c0] : memref<8xi32> -// CHECK: memref.store %[[FIRST_VAL]], %[[OUTPUT_BUF]][%c0] : memref<8xi32> - -// Main scan loop -// CHECK: scf.for %{{.*}} = %c1 to %c8 step %c1 { -// CHECK-NEXT: %[[PREV_IDX:.*]] = arith.subi %{{.*}}, %c1 : index -// CHECK-NEXT: %[[CURR_INPUT:.*]] = memref.load %[[INPUT_BUF]][%{{.*}}] : memref<8xi32> -// CHECK-NEXT: %[[PREV_OUTPUT:.*]] = memref.load %[[OUTPUT_BUF]][%[[PREV_IDX]]] : memref<8xi32> -// CHECK-NEXT: %[[COMBINED:.*]] = arith.maxui %[[PREV_OUTPUT]], %[[CURR_INPUT]] : i32 -// CHECK-NEXT: memref.store %[[COMBINED]], %[[OUTPUT_BUF]][%{{.*}}] : memref<8xi32> -// CHECK-NEXT: } - -// Final materialization -// CHECK: bufferization.materialize_in_destination - - -// === i64 u64 version === -module { - tt.func public @fn_npu_u64( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: i64, %arg3: i64): - %7 = arith.maxui %arg2, %arg3 : i64 - tt.scan.return %7 : i64 - }) : (tensor<8xi64>) -> tensor<8xi64> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- -// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi64> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi64{{.*}}> to memref<8xi64> - -// CHECK: %[[OUTPUT_BUF:.*]] = memref.alloc() : memref<8xi64> - -// Initialize first element -// CHECK: %[[FIRST_VAL:.*]] = memref.load %[[INPUT_BUF]][%c0] : memref<8xi64> -// CHECK: memref.store %[[FIRST_VAL]], %[[OUTPUT_BUF]][%c0] : memref<8xi64> - -// Main scan loop -// CHECK: scf.for %{{.*}} = %c1 to %c8 step %c1 { -// CHECK-NEXT: %[[PREV_IDX:.*]] = arith.subi %{{.*}}, %c1 : index -// CHECK-NEXT: %[[CURR_INPUT:.*]] = 
memref.load %[[INPUT_BUF]][%{{.*}}] : memref<8xi64> -// CHECK-NEXT: %[[PREV_OUTPUT:.*]] = memref.load %[[OUTPUT_BUF]][%[[PREV_IDX]]] : memref<8xi64> -// CHECK-NEXT: %[[COMBINED:.*]] = arith.maxui %[[PREV_OUTPUT]], %[[CURR_INPUT]] : i64 -// CHECK-NEXT: memref.store %[[COMBINED]], %[[OUTPUT_BUF]][%{{.*}}] : memref<8xi64> -// CHECK-NEXT: } - -// Final materialization +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s + +// === i8 u8 version === +module { + tt.func public @fn_npu_u8( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: i8, %arg3: i8): + %7 = arith.maxui %arg2, %arg3 : i8 + tt.scan.return %7 : i8 + }) : (tensor<8xi8>) -> tensor<8xi8> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- +// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi8> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi8{{.*}}> to memref<8xi8> + +// CHECK: %[[OUTPUT_BUF:.*]] = memref.alloc() : memref<8xi8> + +// Initialize first element +// CHECK: %[[FIRST_VAL:.*]] = memref.load %[[INPUT_BUF]][%c0] : memref<8xi8> +// CHECK: memref.store %[[FIRST_VAL]], %[[OUTPUT_BUF]][%c0] : memref<8xi8> + +// Main scan loop +// CHECK: scf.for %{{.*}} = %c1 to %c8 step %c1 { +// CHECK-NEXT: %[[PREV_IDX:.*]] = arith.subi %{{.*}}, %c1 : index +// CHECK-NEXT: %[[CURR_INPUT:.*]] = memref.load %[[INPUT_BUF]][%{{.*}}] : memref<8xi8> +// CHECK-NEXT: %[[PREV_OUTPUT:.*]] = memref.load %[[OUTPUT_BUF]][%[[PREV_IDX]]] : memref<8xi8> +// CHECK-NEXT: %[[COMBINED:.*]] = arith.maxui %[[PREV_OUTPUT]], %[[CURR_INPUT]] : i8 +// CHECK-NEXT: memref.store %[[COMBINED]], %[[OUTPUT_BUF]][%{{.*}}] : memref<8xi8> +// CHECK-NEXT: } + +// Final materialization +// CHECK: bufferization.materialize_in_destination + + +// === i16 u16 version === +module { + tt.func public @fn_npu_u16( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: i16, %arg3: i16): + %7 = arith.maxui %arg2, %arg3 : i16 + tt.scan.return %7 : i16 + }) : (tensor<8xi16>) -> tensor<8xi16> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- +// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi16> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi16{{.*}}> to memref<8xi16> + +// CHECK: %[[OUTPUT_BUF:.*]] = memref.alloc() : memref<8xi16> + +// Initialize first element +// CHECK: %[[FIRST_VAL:.*]] = memref.load %[[INPUT_BUF]][%c0] : memref<8xi16> +// CHECK: memref.store 
%[[FIRST_VAL]], %[[OUTPUT_BUF]][%c0] : memref<8xi16> + +// Main scan loop +// CHECK: scf.for %{{.*}} = %c1 to %c8 step %c1 { +// CHECK-NEXT: %[[PREV_IDX:.*]] = arith.subi %{{.*}}, %c1 : index +// CHECK-NEXT: %[[CURR_INPUT:.*]] = memref.load %[[INPUT_BUF]][%{{.*}}] : memref<8xi16> +// CHECK-NEXT: %[[PREV_OUTPUT:.*]] = memref.load %[[OUTPUT_BUF]][%[[PREV_IDX]]] : memref<8xi16> +// CHECK-NEXT: %[[COMBINED:.*]] = arith.maxui %[[PREV_OUTPUT]], %[[CURR_INPUT]] : i16 +// CHECK-NEXT: memref.store %[[COMBINED]], %[[OUTPUT_BUF]][%{{.*}}] : memref<8xi16> +// CHECK-NEXT: } + +// Final materialization +// CHECK: bufferization.materialize_in_destination + + +// === i32 u32 version === +module { + tt.func public @fn_npu_u32( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: i32, %arg3: i32): + %7 = arith.maxui %arg2, %arg3 : i32 + tt.scan.return %7 : i32 + }) : (tensor<8xi32>) -> tensor<8xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- +// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi32> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi32{{.*}}> to memref<8xi32> + +// CHECK: %[[OUTPUT_BUF:.*]] = memref.alloc() : memref<8xi32> + +// Initialize first element +// CHECK: %[[FIRST_VAL:.*]] = memref.load %[[INPUT_BUF]][%c0] : memref<8xi32> +// CHECK: memref.store %[[FIRST_VAL]], %[[OUTPUT_BUF]][%c0] : memref<8xi32> + +// Main scan loop +// CHECK: scf.for %{{.*}} = %c1 to %c8 step %c1 { +// CHECK-NEXT: %[[PREV_IDX:.*]] = arith.subi %{{.*}}, %c1 : index +// CHECK-NEXT: %[[CURR_INPUT:.*]] = memref.load %[[INPUT_BUF]][%{{.*}}] : memref<8xi32> +// CHECK-NEXT: %[[PREV_OUTPUT:.*]] = memref.load %[[OUTPUT_BUF]][%[[PREV_IDX]]] : memref<8xi32> +// CHECK-NEXT: %[[COMBINED:.*]] = arith.maxui %[[PREV_OUTPUT]], %[[CURR_INPUT]] : i32 +// CHECK-NEXT: memref.store %[[COMBINED]], %[[OUTPUT_BUF]][%{{.*}}] : memref<8xi32> +// CHECK-NEXT: } + +// Final materialization +// CHECK: bufferization.materialize_in_destination + + +// === i64 u64 version === +module { + tt.func public @fn_npu_u64( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: i64, %arg3: i64): + %7 = arith.maxui %arg2, %arg3 : i64 + tt.scan.return %7 : i64 + }) : (tensor<8xi64>) -> tensor<8xi64> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- +// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi64> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi64{{.*}}> to memref<8xi64> + +// CHECK: %[[OUTPUT_BUF:.*]] = memref.alloc() : memref<8xi64> + +// Initialize first element +// CHECK: %[[FIRST_VAL:.*]] = memref.load %[[INPUT_BUF]][%c0] : memref<8xi64> +// CHECK: memref.store %[[FIRST_VAL]], 
%[[OUTPUT_BUF]][%c0] : memref<8xi64> + +// Main scan loop +// CHECK: scf.for %{{.*}} = %c1 to %c8 step %c1 { +// CHECK-NEXT: %[[PREV_IDX:.*]] = arith.subi %{{.*}}, %c1 : index +// CHECK-NEXT: %[[CURR_INPUT:.*]] = memref.load %[[INPUT_BUF]][%{{.*}}] : memref<8xi64> +// CHECK-NEXT: %[[PREV_OUTPUT:.*]] = memref.load %[[OUTPUT_BUF]][%[[PREV_IDX]]] : memref<8xi64> +// CHECK-NEXT: %[[COMBINED:.*]] = arith.maxui %[[PREV_OUTPUT]], %[[CURR_INPUT]] : i64 +// CHECK-NEXT: memref.store %[[COMBINED]], %[[OUTPUT_BUF]][%{{.*}}] : memref<8xi64> +// CHECK-NEXT: } + +// Final materialization // CHECK: bufferization.materialize_in_destination diff --git a/third_party/ascend/test/Conversion/TritonOp/cat.mlir b/third_party/ascend/test/Conversion/TritonOp/cat.mlir index 287a762fe..15206bd7e 100644 --- a/third_party/ascend/test/Conversion/TritonOp/cat.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/cat.mlir @@ -1,204 +1,204 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s - -// === i8 u8 version === -tt.func public @fn_npu_i8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32}, - %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> - %3 = tt.load %2 : tensor<32x!tt.ptr> - %4 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<32x!tt.ptr>, tensor<32xi32> - %6 = tt.load %5 : tensor<32x!tt.ptr> - %7 = tt.cat %3, %6 : tensor<32xi8> -> tensor<64xi8> - %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> - %10 = tt.addptr %9, %8 : tensor<64x!tt.ptr>, tensor<64xi32> - tt.store %10, %7 : tensor<64x!tt.ptr> - tt.return -} - -// CHECK-LABEL: func.func @fn_npu_i8( -// CHECK-NOT: tt.cat -// CHECK: %[[CAST_IN1:.+]] = memref.reinterpret_cast %arg3 to offset: [0], sizes: [32] -// CHECK-SAME: memref to memref<32xi8 -// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<32xi8> -// CHECK: memref.copy %[[CAST_IN1]], %[[ALLOC1]] -// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xi8> -// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[TENSOR1]] into %[[EMPTY]][0] [32] [1] : tensor<32xi8> into tensor<64xi8> -// CHECK: %[[SLICE1:.+]] = tensor.insert_slice {{.*}} into %[[SLICE0]][32] [32] [1] -// CHECK: %[[OUT_CAST:.+]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64] -// CHECK-SAME: memref to memref<64xi8 -// CHECK: bufferization.materialize_in_destination %[[SLICE1]] in writable %[[OUT_CAST]] - - -// === i16 u16 version === -tt.func public @fn_npu_i16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32}, - %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> - %3 = tt.load %2 : tensor<32x!tt.ptr> - %4 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<32x!tt.ptr>, tensor<32xi32> - %6 = tt.load %5 : tensor<32x!tt.ptr> - 
%7 = tt.cat %3, %6 : tensor<32xi16> -> tensor<64xi16> - %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> - %10 = tt.addptr %9, %8 : tensor<64x!tt.ptr>, tensor<64xi32> - tt.store %10, %7 : tensor<64x!tt.ptr> - tt.return -} - -// CHECK-LABEL: func.func @fn_npu_i16( -// CHECK-NOT: tt.cat -// CHECK: %[[CAST_IN1:.+]] = memref.reinterpret_cast %arg3 to offset: [0], sizes: [32] -// CHECK-SAME: memref to memref<32xi16 -// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<32xi16> -// CHECK: memref.copy %[[CAST_IN1]], %[[ALLOC1]] -// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xi16> -// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[TENSOR1]] into %[[EMPTY]][0] [32] [1] : tensor<32xi16> into tensor<64xi16> -// CHECK: %[[SLICE1:.+]] = tensor.insert_slice {{.*}} into %[[SLICE0]][32] [32] [1] -// CHECK: %[[OUT_CAST:.+]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64] -// CHECK-SAME: memref to memref<64xi16 -// CHECK: bufferization.materialize_in_destination %[[SLICE1]] in writable %[[OUT_CAST]] - - -// === i32 u32 version === -tt.func public @fn_npu_i32(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32}, - %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> - %3 = tt.load %2 : tensor<32x!tt.ptr> - %4 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<32x!tt.ptr>, tensor<32xi32> - %6 = tt.load %5 : tensor<32x!tt.ptr> - %7 = tt.cat %3, %6 : tensor<32xi32> -> tensor<64xi32> - %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> - %10 = tt.addptr %9, %8 : tensor<64x!tt.ptr>, tensor<64xi32> - tt.store %10, %7 : tensor<64x!tt.ptr> - tt.return -} - -// CHECK-LABEL: func.func @fn_npu_i32( -// CHECK-NOT: tt.cat -// CHECK: %[[CAST_IN1:.+]] = memref.reinterpret_cast %arg3 to offset: [0], sizes: [32] -// CHECK-SAME: memref to memref<32xi32 -// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<32xi32> -// CHECK: memref.copy %[[CAST_IN1]], %[[ALLOC1]] -// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xi32> -// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[TENSOR1]] into %[[EMPTY]][0] [32] [1] : tensor<32xi32> into tensor<64xi32> -// CHECK: %[[SLICE1:.+]] = tensor.insert_slice {{.*}} into %[[SLICE0]][32] [32] [1] -// CHECK: %[[OUT_CAST:.+]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64] -// CHECK-SAME: memref to memref<64xi32 -// CHECK: bufferization.materialize_in_destination %[[SLICE1]] in writable %[[OUT_CAST]] - - -// === i64 u64 version === -tt.func public @fn_npu_i64(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32}, - %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> - %3 = tt.load %2 : tensor<32x!tt.ptr> - %4 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<32x!tt.ptr>, tensor<32xi32> - %6 = tt.load %5 : tensor<32x!tt.ptr> - %7 = tt.cat %3, %6 : tensor<32xi64> -> tensor<64xi64> - %8 = tt.make_range {end = 64 
: i32, start = 0 : i32} : tensor<64xi32> - %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> - %10 = tt.addptr %9, %8 : tensor<64x!tt.ptr>, tensor<64xi32> - tt.store %10, %7 : tensor<64x!tt.ptr> - tt.return -} - -// CHECK-LABEL: func.func @fn_npu_i64( -// CHECK-NOT: tt.cat -// CHECK: %[[CAST_IN1:.+]] = memref.reinterpret_cast %arg3 to offset: [0], sizes: [32] -// CHECK-SAME: memref to memref<32xi64 -// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<32xi64> -// CHECK: memref.copy %[[CAST_IN1]], %[[ALLOC1]] -// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xi64> -// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[TENSOR1]] into %[[EMPTY]][0] [32] [1] : tensor<32xi64> into tensor<64xi64> -// CHECK: %[[SLICE1:.+]] = tensor.insert_slice {{.*}} into %[[SLICE0]][32] [32] [1] -// CHECK: %[[OUT_CAST:.+]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64] -// CHECK-SAME: memref to memref<64xi64 -// CHECK: bufferization.materialize_in_destination %[[SLICE1]] in writable %[[OUT_CAST]] - - -// === float8_e4m3fn version === -tt.func public @fn_npu_f8E4M3FN(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32}, - %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> - %3 = tt.load %2 : tensor<32x!tt.ptr> - %4 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<32x!tt.ptr>, tensor<32xi32> - %6 = tt.load %5 : tensor<32x!tt.ptr> - %7 = tt.cat %3, %6 : tensor<32xf8E4M3FN> -> tensor<64xf8E4M3FN> - %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> - %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> - %10 = tt.addptr %9, %8 : tensor<64x!tt.ptr>, tensor<64xi32> - tt.store %10, %7 : tensor<64x!tt.ptr> - tt.return -} - -// CHECK-LABEL: func.func @fn_npu_f8E4M3FN( -// CHECK-NOT: tt.cat -// CHECK: %[[CAST_IN1:.+]] = memref.reinterpret_cast %arg3 to offset: [0], sizes: [32] -// CHECK-SAME: memref to memref<32xf8E4M3FN -// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<32xf8E4M3FN> -// CHECK: memref.copy %[[CAST_IN1]], %[[ALLOC1]] -// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf8E4M3FN> -// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[TENSOR1]] into %[[EMPTY]][0] [32] [1] : tensor<32xf8E4M3FN> into tensor<64xf8E4M3FN> -// CHECK: %[[SLICE1:.+]] = tensor.insert_slice {{.*}} into %[[SLICE0]][32] [32] [1] -// CHECK: %[[OUT_CAST:.+]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64] -// CHECK-SAME: memref to memref<64xf8E4M3FN -// CHECK: bufferization.materialize_in_destination %[[SLICE1]] in writable %[[OUT_CAST]] - - -// === float8_e5m2 version === -tt.func public @fn_npu_f8E5M2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32}, - %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { - %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> - %3 = tt.load %2 : tensor<32x!tt.ptr> - %4 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<32x!tt.ptr>, tensor<32xi32> - %6 = tt.load %5 : tensor<32x!tt.ptr> - %7 = tt.cat %3, %6 : tensor<32xf8E5M2> -> tensor<64xf8E5M2> - %8 = tt.make_range {end = 64 : i32, start = 
0 : i32} : tensor<64xi32> - %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> - %10 = tt.addptr %9, %8 : tensor<64x!tt.ptr>, tensor<64xi32> - tt.store %10, %7 : tensor<64x!tt.ptr> - tt.return -} - -// CHECK-LABEL: func.func @fn_npu_f8E5M2( -// CHECK-NOT: tt.cat -// CHECK: %[[CAST_IN1:.+]] = memref.reinterpret_cast %arg3 to offset: [0], sizes: [32] -// CHECK-SAME: memref to memref<32xf8E5M2 -// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<32xf8E5M2> -// CHECK: memref.copy %[[CAST_IN1]], %[[ALLOC1]] -// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf8E5M2> -// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[TENSOR1]] into %[[EMPTY]][0] [32] [1] : tensor<32xf8E5M2> into tensor<64xf8E5M2> -// CHECK: %[[SLICE1:.+]] = tensor.insert_slice {{.*}} into %[[SLICE0]][32] [32] [1] -// CHECK: %[[OUT_CAST:.+]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64] -// CHECK-SAME: memref to memref<64xf8E5M2 +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s + +// === i8 u8 version === +tt.func public @fn_npu_i8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %3 = tt.load %2 : tensor<32x!tt.ptr> + %4 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %6 = tt.load %5 : tensor<32x!tt.ptr> + %7 = tt.cat %3, %6 : tensor<32xi8> -> tensor<64xi8> + %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %10 = tt.addptr %9, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %10, %7 : tensor<64x!tt.ptr> + tt.return +} + +// CHECK-LABEL: func.func @fn_npu_i8( +// CHECK-NOT: tt.cat +// CHECK: %[[CAST_IN1:.+]] = memref.reinterpret_cast %arg3 to offset: [0], sizes: [32] +// CHECK-SAME: memref to memref<32xi8 +// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<32xi8> +// CHECK: memref.copy %[[CAST_IN1]], %[[ALLOC1]] +// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xi8> +// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[TENSOR1]] into %[[EMPTY]][0] [32] [1] : tensor<32xi8> into tensor<64xi8> +// CHECK: %[[SLICE1:.+]] = tensor.insert_slice {{.*}} into %[[SLICE0]][32] [32] [1] +// CHECK: %[[OUT_CAST:.+]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64] +// CHECK-SAME: memref to memref<64xi8 +// CHECK: bufferization.materialize_in_destination %[[SLICE1]] in writable %[[OUT_CAST]] + + +// === i16 u16 version === +tt.func public @fn_npu_i16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %3 = tt.load %2 : tensor<32x!tt.ptr> + %4 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> + %5 = tt.addptr %4, %0 : 
tensor<32x!tt.ptr>, tensor<32xi32> + %6 = tt.load %5 : tensor<32x!tt.ptr> + %7 = tt.cat %3, %6 : tensor<32xi16> -> tensor<64xi16> + %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %10 = tt.addptr %9, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %10, %7 : tensor<64x!tt.ptr> + tt.return +} + +// CHECK-LABEL: func.func @fn_npu_i16( +// CHECK-NOT: tt.cat +// CHECK: %[[CAST_IN1:.+]] = memref.reinterpret_cast %arg3 to offset: [0], sizes: [32] +// CHECK-SAME: memref to memref<32xi16 +// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<32xi16> +// CHECK: memref.copy %[[CAST_IN1]], %[[ALLOC1]] +// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xi16> +// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[TENSOR1]] into %[[EMPTY]][0] [32] [1] : tensor<32xi16> into tensor<64xi16> +// CHECK: %[[SLICE1:.+]] = tensor.insert_slice {{.*}} into %[[SLICE0]][32] [32] [1] +// CHECK: %[[OUT_CAST:.+]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64] +// CHECK-SAME: memref to memref<64xi16 +// CHECK: bufferization.materialize_in_destination %[[SLICE1]] in writable %[[OUT_CAST]] + + +// === i32 u32 version === +tt.func public @fn_npu_i32(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %3 = tt.load %2 : tensor<32x!tt.ptr> + %4 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %6 = tt.load %5 : tensor<32x!tt.ptr> + %7 = tt.cat %3, %6 : tensor<32xi32> -> tensor<64xi32> + %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %10 = tt.addptr %9, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %10, %7 : tensor<64x!tt.ptr> + tt.return +} + +// CHECK-LABEL: func.func @fn_npu_i32( +// CHECK-NOT: tt.cat +// CHECK: %[[CAST_IN1:.+]] = memref.reinterpret_cast %arg3 to offset: [0], sizes: [32] +// CHECK-SAME: memref to memref<32xi32 +// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<32xi32> +// CHECK: memref.copy %[[CAST_IN1]], %[[ALLOC1]] +// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xi32> +// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[TENSOR1]] into %[[EMPTY]][0] [32] [1] : tensor<32xi32> into tensor<64xi32> +// CHECK: %[[SLICE1:.+]] = tensor.insert_slice {{.*}} into %[[SLICE0]][32] [32] [1] +// CHECK: %[[OUT_CAST:.+]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64] +// CHECK-SAME: memref to memref<64xi32 +// CHECK: bufferization.materialize_in_destination %[[SLICE1]] in writable %[[OUT_CAST]] + + +// === i64 u64 version === +tt.func public @fn_npu_i64(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %3 = tt.load %2 : tensor<32x!tt.ptr> + %4 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %6 = tt.load %5 : tensor<32x!tt.ptr> + %7 = 
tt.cat %3, %6 : tensor<32xi64> -> tensor<64xi64> + %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %10 = tt.addptr %9, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %10, %7 : tensor<64x!tt.ptr> + tt.return +} + +// CHECK-LABEL: func.func @fn_npu_i64( +// CHECK-NOT: tt.cat +// CHECK: %[[CAST_IN1:.+]] = memref.reinterpret_cast %arg3 to offset: [0], sizes: [32] +// CHECK-SAME: memref to memref<32xi64 +// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<32xi64> +// CHECK: memref.copy %[[CAST_IN1]], %[[ALLOC1]] +// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xi64> +// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[TENSOR1]] into %[[EMPTY]][0] [32] [1] : tensor<32xi64> into tensor<64xi64> +// CHECK: %[[SLICE1:.+]] = tensor.insert_slice {{.*}} into %[[SLICE0]][32] [32] [1] +// CHECK: %[[OUT_CAST:.+]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64] +// CHECK-SAME: memref to memref<64xi64 +// CHECK: bufferization.materialize_in_destination %[[SLICE1]] in writable %[[OUT_CAST]] + + +// === float8_e4m3fn version === +tt.func public @fn_npu_f8E4M3FN(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %3 = tt.load %2 : tensor<32x!tt.ptr> + %4 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %6 = tt.load %5 : tensor<32x!tt.ptr> + %7 = tt.cat %3, %6 : tensor<32xf8E4M3FN> -> tensor<64xf8E4M3FN> + %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %10 = tt.addptr %9, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %10, %7 : tensor<64x!tt.ptr> + tt.return +} + +// CHECK-LABEL: func.func @fn_npu_f8E4M3FN( +// CHECK-NOT: tt.cat +// CHECK: %[[CAST_IN1:.+]] = memref.reinterpret_cast %arg3 to offset: [0], sizes: [32] +// CHECK-SAME: memref to memref<32xf8E4M3FN +// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<32xf8E4M3FN> +// CHECK: memref.copy %[[CAST_IN1]], %[[ALLOC1]] +// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf8E4M3FN> +// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[TENSOR1]] into %[[EMPTY]][0] [32] [1] : tensor<32xf8E4M3FN> into tensor<64xf8E4M3FN> +// CHECK: %[[SLICE1:.+]] = tensor.insert_slice {{.*}} into %[[SLICE0]][32] [32] [1] +// CHECK: %[[OUT_CAST:.+]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64] +// CHECK-SAME: memref to memref<64xf8E4M3FN +// CHECK: bufferization.materialize_in_destination %[[SLICE1]] in writable %[[OUT_CAST]] + + +// === float8_e5m2 version === +tt.func public @fn_npu_f8E5M2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}, + %arg2: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<32x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %3 = tt.load %2 : tensor<32x!tt.ptr> + %4 = tt.splat %arg2 : !tt.ptr -> tensor<32x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<32x!tt.ptr>, tensor<32xi32> + %6 = tt.load %5 : tensor<32x!tt.ptr> + %7 = tt.cat %3, %6 : 
tensor<32xf8E5M2> -> tensor<64xf8E5M2> + %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<64x!tt.ptr> + %10 = tt.addptr %9, %8 : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %10, %7 : tensor<64x!tt.ptr> + tt.return +} + +// CHECK-LABEL: func.func @fn_npu_f8E5M2( +// CHECK-NOT: tt.cat +// CHECK: %[[CAST_IN1:.+]] = memref.reinterpret_cast %arg3 to offset: [0], sizes: [32] +// CHECK-SAME: memref to memref<32xf8E5M2 +// CHECK: %[[ALLOC1:.+]] = memref.alloc() : memref<32xf8E5M2> +// CHECK: memref.copy %[[CAST_IN1]], %[[ALLOC1]] +// CHECK: %[[TENSOR1:.+]] = bufferization.to_tensor %[[ALLOC1]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf8E5M2> +// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[TENSOR1]] into %[[EMPTY]][0] [32] [1] : tensor<32xf8E5M2> into tensor<64xf8E5M2> +// CHECK: %[[SLICE1:.+]] = tensor.insert_slice {{.*}} into %[[SLICE0]][32] [32] [1] +// CHECK: %[[OUT_CAST:.+]] = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64] +// CHECK-SAME: memref to memref<64xf8E5M2 // CHECK: bufferization.materialize_in_destination %[[SLICE1]] in writable %[[OUT_CAST]] diff --git a/third_party/ascend/test/Conversion/TritonOp/compiler_hint.mlir b/third_party/ascend/test/Conversion/TritonOp/compiler_hint.mlir index d846dccfd..d8c246e1e 100644 --- a/third_party/ascend/test/Conversion/TritonOp/compiler_hint.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/compiler_hint.mlir @@ -1,231 +1,231 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-910-95=False' %s | FileCheck %s - -module { - tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %true = arith.constant true - // CHECK: "llvm.intr.assume"(%true) : (i1) -> () - "llvm.intr.assume"(%true) : (i1) -> () - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - // CHECK-NOT:tt.constancy - // CHECK-NOT:tt.contiguity - // CHECK-NOT:tt.divisibility - %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> - // CHECK: gpu.barrier - gpu.barrier - %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - tt.store %5, %3 : tensor<10x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %true = arith.constant true - // CHECK: "llvm.intr.assume"(%true) : (i1) -> () - "llvm.intr.assume"(%true) : (i1) -> () - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - // CHECK-NOT:tt.constancy - // CHECK-NOT:tt.contiguity - // CHECK-NOT:tt.divisibility - %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : 
tensor<10x!tt.ptr> - // CHECK: gpu.barrier - gpu.barrier - %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - tt.store %5, %3 : tensor<10x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %true = arith.constant true - // CHECK: "llvm.intr.assume"(%true) : (i1) -> () - "llvm.intr.assume"(%true) : (i1) -> () - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - // CHECK-NOT:tt.constancy - // CHECK-NOT:tt.contiguity - // CHECK-NOT:tt.divisibility - %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> - // CHECK: gpu.barrier - gpu.barrier - %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - tt.store %5, %3 : tensor<10x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %true = arith.constant true - // CHECK: "llvm.intr.assume"(%true) : (i1) -> () - "llvm.intr.assume"(%true) : (i1) -> () - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - // CHECK-NOT:tt.constancy - // CHECK-NOT:tt.contiguity - // CHECK-NOT:tt.divisibility - %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> - // CHECK: gpu.barrier - gpu.barrier - %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - tt.store %5, %3 : tensor<10x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %true = arith.constant true - // CHECK: "llvm.intr.assume"(%true) : (i1) -> () - "llvm.intr.assume"(%true) : (i1) -> () - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - // CHECK-NOT:tt.constancy - // CHECK-NOT:tt.contiguity - // CHECK-NOT:tt.divisibility - %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> - // CHECK: gpu.barrier - gpu.barrier - %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - tt.store %5, %3 : tensor<10x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %true = arith.constant true - // CHECK: "llvm.intr.assume"(%true) : (i1) -> () - "llvm.intr.assume"(%true) : (i1) -> () - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = 
tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - // CHECK-NOT:tt.constancy - // CHECK-NOT:tt.contiguity - // CHECK-NOT:tt.divisibility - %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> - // CHECK: gpu.barrier - gpu.barrier - %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - tt.store %5, %3 : tensor<10x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %true = arith.constant true - // CHECK: "llvm.intr.assume"(%true) : (i1) -> () - "llvm.intr.assume"(%true) : (i1) -> () - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - // CHECK-NOT:tt.constancy - // CHECK-NOT:tt.contiguity - // CHECK-NOT:tt.divisibility - %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> - // CHECK: gpu.barrier - gpu.barrier - %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - tt.store %5, %3 : tensor<10x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %true = arith.constant true - // CHECK: "llvm.intr.assume"(%true) : (i1) -> () - "llvm.intr.assume"(%true) : (i1) -> () - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - // CHECK-NOT:tt.constancy - // CHECK-NOT:tt.contiguity - // CHECK-NOT:tt.divisibility - %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> - // CHECK: gpu.barrier - gpu.barrier - %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - tt.store %5, %3 : tensor<10x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %true = arith.constant true - // CHECK: "llvm.intr.assume"(%true) : (i1) -> () - "llvm.intr.assume"(%true) : (i1) -> () - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - // CHECK-NOT:tt.constancy - // CHECK-NOT:tt.contiguity - // CHECK-NOT:tt.divisibility - %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> - // CHECK: gpu.barrier - gpu.barrier - %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - tt.store %5, %3 : tensor<10x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) 
attributes {noinline = false} { - %true = arith.constant true - // CHECK: "llvm.intr.assume"(%true) : (i1) -> () - "llvm.intr.assume"(%true) : (i1) -> () - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - %3 = tt.bitcast %2 : tensor<10x!tt.ptr> -> tensor<10x!tt.ptr> - // CHECK-NOT:tt.constancy - // CHECK-NOT:tt.contiguity - // CHECK-NOT:tt.divisibility - %4 = tt.load %3 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> - // CHECK: gpu.barrier - gpu.barrier - %5 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - %7 = tt.bitcast %6 : tensor<10x!tt.ptr> -> tensor<10x!tt.ptr> - tt.store %7, %4 : tensor<10x!tt.ptr> - tt.return - } +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-910-95=False' %s | FileCheck %s + +module { + tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %true = arith.constant true + // CHECK: "llvm.intr.assume"(%true) : (i1) -> () + "llvm.intr.assume"(%true) : (i1) -> () + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + // CHECK-NOT:tt.constancy + // CHECK-NOT:tt.contiguity + // CHECK-NOT:tt.divisibility + %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> + // CHECK: gpu.barrier + gpu.barrier + %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + tt.store %5, %3 : tensor<10x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %true = arith.constant true + // CHECK: "llvm.intr.assume"(%true) : (i1) -> () + "llvm.intr.assume"(%true) : (i1) -> () + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + // CHECK-NOT:tt.constancy + // CHECK-NOT:tt.contiguity + // CHECK-NOT:tt.divisibility + %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> + // CHECK: gpu.barrier + gpu.barrier + %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + tt.store %5, %3 : tensor<10x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %true = arith.constant true + // CHECK: "llvm.intr.assume"(%true) : (i1) -> () + "llvm.intr.assume"(%true) : (i1) -> () + %0 = tt.make_range {end = 
10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + // CHECK-NOT:tt.constancy + // CHECK-NOT:tt.contiguity + // CHECK-NOT:tt.divisibility + %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> + // CHECK: gpu.barrier + gpu.barrier + %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + tt.store %5, %3 : tensor<10x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %true = arith.constant true + // CHECK: "llvm.intr.assume"(%true) : (i1) -> () + "llvm.intr.assume"(%true) : (i1) -> () + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + // CHECK-NOT:tt.constancy + // CHECK-NOT:tt.contiguity + // CHECK-NOT:tt.divisibility + %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> + // CHECK: gpu.barrier + gpu.barrier + %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + tt.store %5, %3 : tensor<10x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %true = arith.constant true + // CHECK: "llvm.intr.assume"(%true) : (i1) -> () + "llvm.intr.assume"(%true) : (i1) -> () + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + // CHECK-NOT:tt.constancy + // CHECK-NOT:tt.contiguity + // CHECK-NOT:tt.divisibility + %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> + // CHECK: gpu.barrier + gpu.barrier + %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + tt.store %5, %3 : tensor<10x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %true = arith.constant true + // CHECK: "llvm.intr.assume"(%true) : (i1) -> () + "llvm.intr.assume"(%true) : (i1) -> () + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + // CHECK-NOT:tt.constancy + // CHECK-NOT:tt.contiguity + // CHECK-NOT:tt.divisibility + %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> + // CHECK: gpu.barrier + gpu.barrier + %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + tt.store %5, %3 : tensor<10x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public 
@compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %true = arith.constant true + // CHECK: "llvm.intr.assume"(%true) : (i1) -> () + "llvm.intr.assume"(%true) : (i1) -> () + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + // CHECK-NOT:tt.constancy + // CHECK-NOT:tt.contiguity + // CHECK-NOT:tt.divisibility + %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> + // CHECK: gpu.barrier + gpu.barrier + %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + tt.store %5, %3 : tensor<10x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %true = arith.constant true + // CHECK: "llvm.intr.assume"(%true) : (i1) -> () + "llvm.intr.assume"(%true) : (i1) -> () + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + // CHECK-NOT:tt.constancy + // CHECK-NOT:tt.contiguity + // CHECK-NOT:tt.divisibility + %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> + // CHECK: gpu.barrier + gpu.barrier + %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + tt.store %5, %3 : tensor<10x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %true = arith.constant true + // CHECK: "llvm.intr.assume"(%true) : (i1) -> () + "llvm.intr.assume"(%true) : (i1) -> () + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + // CHECK-NOT:tt.constancy + // CHECK-NOT:tt.contiguity + // CHECK-NOT:tt.divisibility + %3 = tt.load %2 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity = dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> + // CHECK: gpu.barrier + gpu.barrier + %4 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + tt.store %5, %3 : tensor<10x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @compile_hint_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %true = arith.constant true + // CHECK: "llvm.intr.assume"(%true) : (i1) -> () + "llvm.intr.assume"(%true) : (i1) -> () + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + %3 = tt.bitcast %2 : tensor<10x!tt.ptr> -> tensor<10x!tt.ptr> + // CHECK-NOT:tt.constancy + // CHECK-NOT:tt.contiguity + // CHECK-NOT:tt.divisibility + %4 = tt.load %3 {tt.constancy = dense<1> : tensor<1xi32>, tt.contiguity 
= dense<1> : tensor<1xi32>, tt.divisibility = dense<1> : tensor<1xi32>} : tensor<10x!tt.ptr> + // CHECK: gpu.barrier + gpu.barrier + %5 = tt.splat %arg1 : !tt.ptr -> tensor<10x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + %7 = tt.bitcast %6 : tensor<10x!tt.ptr> -> tensor<10x!tt.ptr> + tt.store %7, %4 : tensor<10x!tt.ptr> + tt.return + } } diff --git a/third_party/ascend/test/Conversion/TritonOp/cumprod.mlir b/third_party/ascend/test/Conversion/TritonOp/cumprod.mlir index 5b9051009..b22133194 100644 --- a/third_party/ascend/test/Conversion/TritonOp/cumprod.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/cumprod.mlir @@ -1,136 +1,136 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s - -// === i8 u8 version === -module { - tt.func public @fn_npu_u8( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: i8, %arg3: i8): - %7 = arith.muli %arg2, %arg3 : i8 - tt.scan.return %7 : i8 - }) : (tensor<8xi8>) -> tensor<8xi8> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- - -// CHECK: func.func private @triton_cumprod_0(tensor<8xi8>, i32, i1) -> tensor<8xi8> -// CHECK: %false = arith.constant false -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi8> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi8{{.*}}> to memref<8xi8> -// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi8> -// CHECK: %{{.*}} = call @triton_cumprod_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi8>, i32, i1) -> tensor<8xi8> -// CHECK: bufferization.materialize_in_destination - - -// === i16 u16 version === -module { - tt.func public @fn_npu_u16( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: i16, %arg3: i16): - %7 = arith.muli %arg2, %arg3 : i16 - tt.scan.return %7 : i16 - }) : (tensor<8xi16>) -> tensor<8xi16> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- - -// CHECK: func.func private @triton_cumprod_0(tensor<8xi16>, i32, i1) -> tensor<8xi16> -// CHECK: %false = arith.constant false -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi16> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi16{{.*}}> to memref<8xi16> -// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi16> -// CHECK: %{{.*}} = 
call @triton_cumprod_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi16>, i32, i1) -> tensor<8xi16> -// CHECK: bufferization.materialize_in_destination - - -// === i32 u32 version === -module { - tt.func public @fn_npu_u32( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: i32, %arg3: i32): - %7 = arith.muli %arg2, %arg3 : i32 - tt.scan.return %7 : i32 - }) : (tensor<8xi32>) -> tensor<8xi32> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- - -// CHECK: func.func private @triton_cumprod_0(tensor<8xi32>, i32, i1) -> tensor<8xi32> -// CHECK: %false = arith.constant false -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi32> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi32{{.*}}> to memref<8xi32> -// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi32> -// CHECK: %{{.*}} = call @triton_cumprod_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi32>, i32, i1) -> tensor<8xi32> -// CHECK: bufferization.materialize_in_destination - - -// === i64 u64 version === -module { - tt.func public @fn_npu_u64( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: i64, %arg3: i64): - %7 = arith.muli %arg2, %arg3 : i64 - tt.scan.return %7 : i64 - }) : (tensor<8xi64>) -> tensor<8xi64> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- - -// CHECK: func.func private @triton_cumprod_0(tensor<8xi64>, i32, i1) -> tensor<8xi64> -// CHECK: %false = arith.constant false -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi64> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi64{{.*}}> to memref<8xi64> -// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi64> -// CHECK: %{{.*}} = call @triton_cumprod_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi64>, i32, i1) -> tensor<8xi64> +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s + +// === i8 u8 version === +module { + tt.func public @fn_npu_u8( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) 
<{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: i8, %arg3: i8): + %7 = arith.muli %arg2, %arg3 : i8 + tt.scan.return %7 : i8 + }) : (tensor<8xi8>) -> tensor<8xi8> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK: func.func private @triton_cumprod_0(tensor<8xi8>, i32, i1) -> tensor<8xi8> +// CHECK: %false = arith.constant false +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi8> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi8{{.*}}> to memref<8xi8> +// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi8> +// CHECK: %{{.*}} = call @triton_cumprod_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi8>, i32, i1) -> tensor<8xi8> +// CHECK: bufferization.materialize_in_destination + + +// === i16 u16 version === +module { + tt.func public @fn_npu_u16( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: i16, %arg3: i16): + %7 = arith.muli %arg2, %arg3 : i16 + tt.scan.return %7 : i16 + }) : (tensor<8xi16>) -> tensor<8xi16> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK: func.func private @triton_cumprod_0(tensor<8xi16>, i32, i1) -> tensor<8xi16> +// CHECK: %false = arith.constant false +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi16> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi16{{.*}}> to memref<8xi16> +// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi16> +// CHECK: %{{.*}} = call @triton_cumprod_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi16>, i32, i1) -> tensor<8xi16> +// CHECK: bufferization.materialize_in_destination + + +// === i32 u32 version === +module { + tt.func public @fn_npu_u32( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: i32, %arg3: i32): + %7 = arith.muli %arg2, %arg3 : i32 + tt.scan.return %7 : i32 + }) : (tensor<8xi32>) -> tensor<8xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK: func.func private @triton_cumprod_0(tensor<8xi32>, i32, i1) -> tensor<8xi32> +// CHECK: %false = arith.constant false +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi32> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi32{{.*}}> to memref<8xi32> +// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi32> +// CHECK: %{{.*}} = call @triton_cumprod_0(%[[TENSOR]], %c0_i32, 
%false) : (tensor<8xi32>, i32, i1) -> tensor<8xi32> +// CHECK: bufferization.materialize_in_destination + + +// === i64 u64 version === +module { + tt.func public @fn_npu_u64( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: i64, %arg3: i64): + %7 = arith.muli %arg2, %arg3 : i64 + tt.scan.return %7 : i64 + }) : (tensor<8xi64>) -> tensor<8xi64> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK: func.func private @triton_cumprod_0(tensor<8xi64>, i32, i1) -> tensor<8xi64> +// CHECK: %false = arith.constant false +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi64> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi64{{.*}}> to memref<8xi64> +// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi64> +// CHECK: %{{.*}} = call @triton_cumprod_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi64>, i32, i1) -> tensor<8xi64> // CHECK: bufferization.materialize_in_destination diff --git a/third_party/ascend/test/Conversion/TritonOp/cumsum.mlir b/third_party/ascend/test/Conversion/TritonOp/cumsum.mlir index 376d0b7c1..ff4f3339a 100644 --- a/third_party/ascend/test/Conversion/TritonOp/cumsum.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/cumsum.mlir @@ -1,204 +1,204 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s - -// === i8 u8 version === -module { - tt.func public @fn_npu_u8( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: i8, %arg3: i8): - %7 = arith.addi %arg2, %arg3 : i8 - tt.scan.return %7 : i8 - }) : (tensor<8xi8>) -> tensor<8xi8> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- - -// CHECK: func.func private @triton_cumsum_0(tensor<8xi8>, i32, i1) -> tensor<8xi8> -// CHECK: %false = arith.constant false -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi8> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi8{{.*}}> to memref<8xi8> -// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi8> -// CHECK: %{{.*}} = call @triton_cumsum_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi8>, i32, i1) -> tensor<8xi8> -// CHECK: bufferization.materialize_in_destination - - -// === i16 u16 version === -module { - tt.func public @fn_npu_u16( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - 
%arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: i16, %arg3: i16): - %7 = arith.addi %arg2, %arg3 : i16 - tt.scan.return %7 : i16 - }) : (tensor<8xi16>) -> tensor<8xi16> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- - -// CHECK: func.func private @triton_cumsum_0(tensor<8xi16>, i32, i1) -> tensor<8xi16> -// CHECK: %false = arith.constant false -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi16> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi16{{.*}}> to memref<8xi16> -// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi16> -// CHECK: %{{.*}} = call @triton_cumsum_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi16>, i32, i1) -> tensor<8xi16> -// CHECK: bufferization.materialize_in_destination - - -// === i32 u32 version === -module { - tt.func public @fn_npu_u32( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: i32, %arg3: i32): - %7 = arith.addi %arg2, %arg3 : i32 - tt.scan.return %7 : i32 - }) : (tensor<8xi32>) -> tensor<8xi32> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- - -// CHECK: func.func private @triton_cumsum_0(tensor<8xi32>, i32, i1) -> tensor<8xi32> -// CHECK: %false = arith.constant false -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi32> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi32{{.*}}> to memref<8xi32> -// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi32> -// CHECK: %{{.*}} = call @triton_cumsum_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi32>, i32, i1) -> tensor<8xi32> -// CHECK: bufferization.materialize_in_destination - - -// === i64 u64 version === -module { - tt.func public @fn_npu_u64( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: i64, %arg3: i64): - %7 = arith.addi %arg2, %arg3 : i64 - tt.scan.return %7 : i64 - }) : (tensor<8xi64>) -> tensor<8xi64> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- - -// CHECK: func.func private @triton_cumsum_0(tensor<8xi64>, i32, i1) -> tensor<8xi64> -// CHECK: %false = arith.constant false -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: 
%[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi64> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi64{{.*}}> to memref<8xi64> -// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi64> -// CHECK: %{{.*}} = call @triton_cumsum_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi64>, i32, i1) -> tensor<8xi64> -// CHECK: bufferization.materialize_in_destination - - -// === f8E4M3FN version === -module { - tt.func public @fn_npu_f8E4M3FN( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: f8E4M3FN, %arg3: f8E4M3FN): - %7 = arith.addf %arg2, %arg3 : f8E4M3FN - tt.scan.return %7 : f8E4M3FN - }) : (tensor<8xf8E4M3FN>) -> tensor<8xf8E4M3FN> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- - -// CHECK: func.func private @triton_cumsum_0(tensor<8xf8E4M3FN>, i32, i1) -> tensor<8xf8E4M3FN> -// CHECK: %false = arith.constant false -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %[[INPUT_BUF:.+]] = memref.alloc() : memref<8xf8E4M3FN> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xf8E4M3FN{{.*}}> to memref<8xf8E4M3FN> -// CHECK: %[[TENSOR:.+]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xf8E4M3FN> -// CHECK: %{{.*}} = call @triton_cumsum_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xf8E4M3FN>, i32, i1) -> tensor<8xf8E4M3FN> -// CHECK: bufferization.materialize_in_destination - - -// === f8E5M2 version === -module { - tt.func public @fn_npu_f8E5M2( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ - ^bb0(%arg2: f8E5M2, %arg3: f8E5M2): - %7 = arith.addf %arg2, %arg3 : f8E5M2 - tt.scan.return %7 : f8E5M2 - }) : (tensor<8xf8E5M2>) -> tensor<8xf8E5M2> - %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - tt.store %6, %4 : tensor<8x!tt.ptr> - tt.return - } -} - -// ----- - -// CHECK: func.func private @triton_cumsum_0(tensor<8xf8E5M2>, i32, i1) -> tensor<8xf8E5M2> -// CHECK: %false = arith.constant false -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %[[INPUT_BUF:.+]] = memref.alloc() : memref<8xf8E5M2> -// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xf8E5M2{{.*}}> to memref<8xf8E5M2> -// CHECK: %[[TENSOR:.+]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xf8E5M2> -// CHECK: %{{.*}} = call @triton_cumsum_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xf8E5M2>, i32, i1) -> tensor<8xf8E5M2> +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s + +// === i8 u8 version === +module { 
+ tt.func public @fn_npu_u8( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: i8, %arg3: i8): + %7 = arith.addi %arg2, %arg3 : i8 + tt.scan.return %7 : i8 + }) : (tensor<8xi8>) -> tensor<8xi8> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK: func.func private @triton_cumsum_0(tensor<8xi8>, i32, i1) -> tensor<8xi8> +// CHECK: %false = arith.constant false +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi8> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi8{{.*}}> to memref<8xi8> +// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi8> +// CHECK: %{{.*}} = call @triton_cumsum_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi8>, i32, i1) -> tensor<8xi8> +// CHECK: bufferization.materialize_in_destination + + +// === i16 u16 version === +module { + tt.func public @fn_npu_u16( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: i16, %arg3: i16): + %7 = arith.addi %arg2, %arg3 : i16 + tt.scan.return %7 : i16 + }) : (tensor<8xi16>) -> tensor<8xi16> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK: func.func private @triton_cumsum_0(tensor<8xi16>, i32, i1) -> tensor<8xi16> +// CHECK: %false = arith.constant false +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi16> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi16{{.*}}> to memref<8xi16> +// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi16> +// CHECK: %{{.*}} = call @triton_cumsum_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi16>, i32, i1) -> tensor<8xi16> +// CHECK: bufferization.materialize_in_destination + + +// === i32 u32 version === +module { + tt.func public @fn_npu_u32( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: i32, %arg3: i32): + %7 = arith.addi %arg2, %arg3 : i32 + tt.scan.return %7 : i32 + }) : (tensor<8xi32>) -> tensor<8xi32> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK: func.func private @triton_cumsum_0(tensor<8xi32>, i32, i1) -> tensor<8xi32> +// CHECK: %false = arith.constant 
false +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi32> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi32{{.*}}> to memref<8xi32> +// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi32> +// CHECK: %{{.*}} = call @triton_cumsum_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi32>, i32, i1) -> tensor<8xi32> +// CHECK: bufferization.materialize_in_destination + + +// === i64 u64 version === +module { + tt.func public @fn_npu_u64( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: i64, %arg3: i64): + %7 = arith.addi %arg2, %arg3 : i64 + tt.scan.return %7 : i64 + }) : (tensor<8xi64>) -> tensor<8xi64> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK: func.func private @triton_cumsum_0(tensor<8xi64>, i32, i1) -> tensor<8xi64> +// CHECK: %false = arith.constant false +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[INPUT_BUF:.*]] = memref.alloc() : memref<8xi64> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xi64{{.*}}> to memref<8xi64> +// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xi64> +// CHECK: %{{.*}} = call @triton_cumsum_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xi64>, i32, i1) -> tensor<8xi64> +// CHECK: bufferization.materialize_in_destination + + +// === f8E4M3FN version === +module { + tt.func public @fn_npu_f8E4M3FN( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: f8E4M3FN, %arg3: f8E4M3FN): + %7 = arith.addf %arg2, %arg3 : f8E4M3FN + tt.scan.return %7 : f8E4M3FN + }) : (tensor<8xf8E4M3FN>) -> tensor<8xf8E4M3FN> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK: func.func private @triton_cumsum_0(tensor<8xf8E4M3FN>, i32, i1) -> tensor<8xf8E4M3FN> +// CHECK: %false = arith.constant false +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[INPUT_BUF:.+]] = memref.alloc() : memref<8xf8E4M3FN> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xf8E4M3FN{{.*}}> to memref<8xf8E4M3FN> +// CHECK: %[[TENSOR:.+]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xf8E4M3FN> +// CHECK: %{{.*}} = call @triton_cumsum_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xf8E4M3FN>, i32, i1) -> tensor<8xf8E4M3FN> +// CHECK: bufferization.materialize_in_destination + + +// === f8E5M2 version === +module { + tt.func public @fn_npu_f8E5M2( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr> 
+ %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = "tt.scan"(%3) <{axis = 0 : i32, reverse = false}> ({ + ^bb0(%arg2: f8E5M2, %arg3: f8E5M2): + %7 = arith.addf %arg2, %arg3 : f8E5M2 + tt.scan.return %7 : f8E5M2 + }) : (tensor<8xf8E5M2>) -> tensor<8xf8E5M2> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %6 = tt.addptr %5, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + tt.store %6, %4 : tensor<8x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK: func.func private @triton_cumsum_0(tensor<8xf8E5M2>, i32, i1) -> tensor<8xf8E5M2> +// CHECK: %false = arith.constant false +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[INPUT_BUF:.+]] = memref.alloc() : memref<8xf8E5M2> +// CHECK: memref.copy {{.*}}, %[[INPUT_BUF]] : memref<8xf8E5M2{{.*}}> to memref<8xf8E5M2> +// CHECK: %[[TENSOR:.+]] = bufferization.to_tensor %[[INPUT_BUF]] restrict writable : memref<8xf8E5M2> +// CHECK: %{{.*}} = call @triton_cumsum_0(%[[TENSOR]], %c0_i32, %false) : (tensor<8xf8E5M2>, i32, i1) -> tensor<8xf8E5M2> // CHECK: bufferization.materialize_in_destination diff --git a/third_party/ascend/test/Conversion/TritonOp/device_print_and_assert.mlir b/third_party/ascend/test/Conversion/TritonOp/device_print_and_assert.mlir index c722d6f91..f6fbbd3b1 100644 --- a/third_party/ascend/test/Conversion/TritonOp/device_print_and_assert.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/device_print_and_assert.mlir @@ -1,166 +1,166 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-910-95=False' %s | FileCheck %s - -module { - // CHECK: func.func private @triton_print_0(tensor<10xi8>) attributes {hex = false, prefix = " Type: uint8: "} - // CHECK: func.func @device_print_kernel - tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - %3 = tt.load %2 : tensor<10x!tt.ptr> - // CHECK: call @triton_print_0 - tt.print " Type: uint8: " {hex = false, isSigned = array} : %3 : tensor<10xi8> - tt.return - } -} - -// ----- - -module { - // CHECK: func.func private @triton_print_0(tensor<10xi16>) attributes {hex = false, prefix = " Type: uint16: "} - // CHECK: func.func @device_print_kernel - tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - %3 = tt.load %2 : tensor<10x!tt.ptr> - // CHECK: call @triton_print_0 - tt.print " Type: uint16: " {hex = false, isSigned = array} : %3 : tensor<10xi16> - tt.return - } -} - -// ----- - -module { - // CHECK: func.func private @triton_print_0(tensor<10xi32>) attributes {hex = false, prefix = " Type: uint32: "} - // CHECK: func.func @device_print_kernel - tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : 
!tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - %3 = tt.load %2 : tensor<10x!tt.ptr> - // CHECK: call @triton_print_0 - tt.print " Type: uint32: " {hex = false, isSigned = array} : %3 : tensor<10xi32> - tt.return - } -} - -// ----- - -module { - // CHECK: func.func private @triton_print_0(tensor<10xi64>) attributes {hex = false, prefix = " Type: uint64: "} - // CHECK: func.func @device_print_kernel - tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - %3 = tt.load %2 : tensor<10x!tt.ptr> - // CHECK: call @triton_print_0 - tt.print " Type: uint64: " {hex = false, isSigned = array} : %3 : tensor<10xi64> - tt.return - } -} - -// ----- - -module { - // CHECK: func.func private @triton_print_0(tensor<10xf32>) attributes {hex = false, prefix = " Type: float32: "} - // CHECK: func.func @device_print_kernel - tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - %3 = tt.load %2 : tensor<10x!tt.ptr> - // CHECK: call @triton_print_0 - tt.print " Type: float32: " {hex = false, isSigned = array} : %3 : tensor<10xf32> - tt.return - } -} - -// ----- - -module { - // CHECK: func.func private @triton_print_0(tensor<10xf16>) attributes {hex = false, prefix = " Type: float16: "} - // CHECK: func.func @device_print_kernel - tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - %3 = tt.load %2 : tensor<10x!tt.ptr> - // CHECK: call @triton_print_0 - tt.print " Type: float16: " {hex = false, isSigned = array} : %3 : tensor<10xf16> - tt.return - } -} - -// ----- - -module { - // CHECK: func.func private @triton_print_0(tensor<10xbf16>) attributes {hex = false, prefix = " Type: bfloat16: "} - // CHECK: func.func @device_print_kernel - tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> - %3 = tt.load %2 : tensor<10x!tt.ptr> - // CHECK: call @triton_print_0 - tt.print " Type: bfloat16: " {hex = false, isSigned = array} : %3 : tensor<10xbf16> - tt.return - } -} - -// ----- - -module { - // CHECK: func.func private @triton_assert_0(tensor<10xi1>) attributes {msg = "device_assert fail!"} - // CHECK: func.func private @triton_print_0(tensor<10xi1>) attributes {hex = false, prefix = " Type: bool (int1): "} - // CHECK: func.func @device_print_kernel - tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %cst = arith.constant dense : tensor<10xi1> - %cst_0 = arith.constant dense<0> : tensor<10xi8> - %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> - %2 = tt.addptr %1, 
%0 : tensor<10x!tt.ptr>, tensor<10xi32> - %3 = tt.bitcast %2 : tensor<10x!tt.ptr> -> tensor<10x!tt.ptr> - %4 = tt.load %3 : tensor<10x!tt.ptr> - %5 = arith.cmpi ne, %4, %cst_0 : tensor<10xi8> - // CHECK: call @triton_assert_0 - // CHECK: call @triton_print_0 - tt.assert %cst, "device_assert fail!" : tensor<10xi1> - tt.print " Type: bool (int1): " {hex = false, isSigned = array} : %5 : tensor<10xi1> - tt.return - } -} - -// ----- - -module { - // CHECK: func.func private @triton_print_0(tensor<16xf8E5M2>) - // CHECK: func.func @device_print_kernel - tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<16x!tt.ptr>, tensor<16xi32> - %3 = tt.load %2 : tensor<16x!tt.ptr> - // CHECK: call @triton_print_0 - tt.print " val: " {hex = false, isSigned = array} : %3 : tensor<16xf8E5M2> - tt.return - } -} - -// ----- - -module { - // CHECK: func.func private @triton_print_0(tensor<16xf8E4M3FN>) attributes {hex = false, prefix = " val: "} - // CHECK: func.func @device_print_kernel - tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<16x!tt.ptr>, tensor<16xi32> - %3 = tt.load %2 : tensor<16x!tt.ptr> - // CHECK: call @triton_print_0 - tt.print " val: " {hex = false, isSigned = array} : %3 : tensor<16xf8E4M3FN> - tt.return - } +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-910-95=False' %s | FileCheck %s + +module { + // CHECK: func.func private @triton_print_0(tensor<10xi8>) attributes {hex = false, prefix = " Type: uint8: "} + // CHECK: func.func @device_print_kernel + tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + %3 = tt.load %2 : tensor<10x!tt.ptr> + // CHECK: call @triton_print_0 + tt.print " Type: uint8: " {hex = false, isSigned = array} : %3 : tensor<10xi8> + tt.return + } +} + +// ----- + +module { + // CHECK: func.func private @triton_print_0(tensor<10xi16>) attributes {hex = false, prefix = " Type: uint16: "} + // CHECK: func.func @device_print_kernel + tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + %3 = tt.load %2 : tensor<10x!tt.ptr> + // CHECK: call @triton_print_0 + tt.print " Type: uint16: " {hex = false, isSigned = array} : %3 : tensor<10xi16> + tt.return + } +} + +// ----- + +module { + // CHECK: func.func private @triton_print_0(tensor<10xi32>) attributes {hex = false, prefix = " Type: uint32: "} + // CHECK: func.func @device_print_kernel + tt.func public 
@device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + %3 = tt.load %2 : tensor<10x!tt.ptr> + // CHECK: call @triton_print_0 + tt.print " Type: uint32: " {hex = false, isSigned = array} : %3 : tensor<10xi32> + tt.return + } +} + +// ----- + +module { + // CHECK: func.func private @triton_print_0(tensor<10xi64>) attributes {hex = false, prefix = " Type: uint64: "} + // CHECK: func.func @device_print_kernel + tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + %3 = tt.load %2 : tensor<10x!tt.ptr> + // CHECK: call @triton_print_0 + tt.print " Type: uint64: " {hex = false, isSigned = array} : %3 : tensor<10xi64> + tt.return + } +} + +// ----- + +module { + // CHECK: func.func private @triton_print_0(tensor<10xf32>) attributes {hex = false, prefix = " Type: float32: "} + // CHECK: func.func @device_print_kernel + tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + %3 = tt.load %2 : tensor<10x!tt.ptr> + // CHECK: call @triton_print_0 + tt.print " Type: float32: " {hex = false, isSigned = array} : %3 : tensor<10xf32> + tt.return + } +} + +// ----- + +module { + // CHECK: func.func private @triton_print_0(tensor<10xf16>) attributes {hex = false, prefix = " Type: float16: "} + // CHECK: func.func @device_print_kernel + tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + %3 = tt.load %2 : tensor<10x!tt.ptr> + // CHECK: call @triton_print_0 + tt.print " Type: float16: " {hex = false, isSigned = array} : %3 : tensor<10xf16> + tt.return + } +} + +// ----- + +module { + // CHECK: func.func private @triton_print_0(tensor<10xbf16>) attributes {hex = false, prefix = " Type: bfloat16: "} + // CHECK: func.func @device_print_kernel + tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + %3 = tt.load %2 : tensor<10x!tt.ptr> + // CHECK: call @triton_print_0 + tt.print " Type: bfloat16: " {hex = false, isSigned = array} : %3 : tensor<10xbf16> + tt.return + } +} + +// ----- + +module { + // CHECK: func.func private @triton_assert_0(tensor<10xi1>) attributes {msg = "device_assert fail!"} + // CHECK: func.func private @triton_print_0(tensor<10xi1>) attributes {hex = false, prefix = " Type: bool (int1): "} + // CHECK: func.func @device_print_kernel + tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense : tensor<10xi1> + 
%cst_0 = arith.constant dense<0> : tensor<10xi8> + %0 = tt.make_range {end = 10 : i32, start = 0 : i32} : tensor<10xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<10x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<10x!tt.ptr>, tensor<10xi32> + %3 = tt.bitcast %2 : tensor<10x!tt.ptr> -> tensor<10x!tt.ptr> + %4 = tt.load %3 : tensor<10x!tt.ptr> + %5 = arith.cmpi ne, %4, %cst_0 : tensor<10xi8> + // CHECK: call @triton_assert_0 + // CHECK: call @triton_print_0 + tt.assert %cst, "device_assert fail!" : tensor<10xi1> + tt.print " Type: bool (int1): " {hex = false, isSigned = array} : %5 : tensor<10xi1> + tt.return + } +} + +// ----- + +module { + // CHECK: func.func private @triton_print_0(tensor<16xf8E5M2>) + // CHECK: func.func @device_print_kernel + tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<16x!tt.ptr>, tensor<16xi32> + %3 = tt.load %2 : tensor<16x!tt.ptr> + // CHECK: call @triton_print_0 + tt.print " val: " {hex = false, isSigned = array} : %3 : tensor<16xf8E5M2> + tt.return + } +} + +// ----- + +module { + // CHECK: func.func private @triton_print_0(tensor<16xf8E4M3FN>) attributes {hex = false, prefix = " val: "} + // CHECK: func.func @device_print_kernel + tt.func public @device_print_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<16x!tt.ptr>, tensor<16xi32> + %3 = tt.load %2 : tensor<16x!tt.ptr> + // CHECK: call @triton_print_0 + tt.print " val: " {hex = false, isSigned = array} : %3 : tensor<16xf8E4M3FN> + tt.return + } } diff --git a/third_party/ascend/test/Conversion/TritonOp/full.mlir b/third_party/ascend/test/Conversion/TritonOp/full.mlir index 4c0825f86..8846f4f8e 100644 --- a/third_party/ascend/test/Conversion/TritonOp/full.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/full.mlir @@ -1,189 +1,189 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s - -// === i8 u8 version === -module { - tt.func public @fn_npu_u8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<100> : tensor<8x8x4xi8> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : 
tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// CHECK: arith.constant 100 : i8 -// CHECK: tensor.empty() : tensor<8x8x4xi8> -// CHECK: linalg.fill ins(%{{.*}} : i8) outs(%{{.*}} : tensor<8x8x4xi8>) -> tensor<8x8x4xi8> -// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] -// CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} - -// === i16 u16 version === -module { - tt.func public @fn_npu_u16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<100> : tensor<8x8x4xi16> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// CHECK: arith.constant 100 : i16 -// CHECK: tensor.empty() : tensor<8x8x4xi16> -// CHECK: linalg.fill ins(%{{.*}} : i16) outs(%{{.*}} : tensor<8x8x4xi16>) -> tensor<8x8x4xi16> -// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] -// CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} - -// === i32 u32 version === -module { - tt.func public @fn_npu_i32(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<100> : tensor<8x8x4xi32> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store 
%15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// CHECK: arith.constant 100 : i32 -// CHECK: tensor.empty() : tensor<8x8x4xi32> -// CHECK: linalg.fill ins(%{{.*}} : i32) outs(%{{.*}} : tensor<8x8x4xi32>) -> tensor<8x8x4xi32> -// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] -// CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} - -// === i64 u64 version === -module { - tt.func public @fn_npu_i64(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<100> : tensor<8x8x4xi64> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// CHECK: arith.constant 100 : i64 -// CHECK: tensor.empty() : tensor<8x8x4xi64> -// CHECK: linalg.fill ins(%{{.*}} : i64) outs(%{{.*}} : tensor<8x8x4xi64>) -> tensor<8x8x4xi64> -// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] -// CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} - - -// === float8_e4m3fn version === -module { - tt.func public @fn_npu_f8E4M3FN(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<100> : tensor<8x8x4xf8E4M3FN> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// CHECK: arith.constant 100.0 : f8E4M3FN -// CHECK: tensor.empty() : tensor<8x8x4xf8E4M3FN> -// CHECK: linalg.fill ins(%{{.*}} : f8E4M3FN) 
outs(%{{.*}} : tensor<8x8x4xf8E4M3FN>) -> tensor<8x8x4xf8E4M3FN> -// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] -// CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} - - -// === float8_e5m2 version === -module { - tt.func public @fn_npu_f8E5M2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<100> : tensor<8x8x4xf8E5M2> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// CHECK: arith.constant 100.0 : f8E5M2 -// CHECK: tensor.empty() : tensor<8x8x4xf8E5M2> -// CHECK: linalg.fill ins(%{{.*}} : f8E5M2) outs(%{{.*}} : tensor<8x8x4xf8E5M2>) -> tensor<8x8x4xf8E5M2> -// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s + +// === i8 u8 version === +module { + tt.func public @fn_npu_u8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<100> : tensor<8x8x4xi8> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// CHECK: arith.constant 100 : i8 +// CHECK: tensor.empty() : 
tensor<8x8x4xi8> +// CHECK: linalg.fill ins(%{{.*}} : i8) outs(%{{.*}} : tensor<8x8x4xi8>) -> tensor<8x8x4xi8> +// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] +// CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} + +// === i16 u16 version === +module { + tt.func public @fn_npu_u16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<100> : tensor<8x8x4xi16> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// CHECK: arith.constant 100 : i16 +// CHECK: tensor.empty() : tensor<8x8x4xi16> +// CHECK: linalg.fill ins(%{{.*}} : i16) outs(%{{.*}} : tensor<8x8x4xi16>) -> tensor<8x8x4xi16> +// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] +// CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} + +// === i32 u32 version === +module { + tt.func public @fn_npu_i32(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<100> : tensor<8x8x4xi32> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// CHECK: arith.constant 100 : i32 +// CHECK: tensor.empty() : tensor<8x8x4xi32> +// CHECK: linalg.fill ins(%{{.*}} : i32) outs(%{{.*}} : tensor<8x8x4xi32>) -> tensor<8x8x4xi32> +// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] +// CHECK: 
bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} + +// === i64 u64 version === +module { + tt.func public @fn_npu_i64(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<100> : tensor<8x8x4xi64> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// CHECK: arith.constant 100 : i64 +// CHECK: tensor.empty() : tensor<8x8x4xi64> +// CHECK: linalg.fill ins(%{{.*}} : i64) outs(%{{.*}} : tensor<8x8x4xi64>) -> tensor<8x8x4xi64> +// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] +// CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} + + +// === float8_e4m3fn version === +module { + tt.func public @fn_npu_f8E4M3FN(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<100> : tensor<8x8x4xf8E4M3FN> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// CHECK: arith.constant 100.0 : f8E4M3FN +// CHECK: tensor.empty() : tensor<8x8x4xf8E4M3FN> +// CHECK: linalg.fill ins(%{{.*}} : f8E4M3FN) outs(%{{.*}} : tensor<8x8x4xf8E4M3FN>) -> tensor<8x8x4xf8E4M3FN> +// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] +// CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} + + +// === float8_e5m2 version === +module { + tt.func public @fn_npu_f8E5M2(%arg0: !tt.ptr 
{tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<100> : tensor<8x8x4xf8E5M2> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// CHECK: arith.constant 100.0 : f8E5M2 +// CHECK: tensor.empty() : tensor<8x8x4xf8E5M2> +// CHECK: linalg.fill ins(%{{.*}} : f8E5M2) outs(%{{.*}} : tensor<8x8x4xf8E5M2>) -> tensor<8x8x4xf8E5M2> +// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] // CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} diff --git a/third_party/ascend/test/Conversion/TritonOp/gather.mlir b/third_party/ascend/test/Conversion/TritonOp/gather.mlir index 0eaedc36a..3e4a26d71 100644 --- a/third_party/ascend/test/Conversion/TritonOp/gather.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/gather.mlir @@ -1,137 +1,137 @@ -// RUN: triton-adapter-opt %s --triton-linearize '--discrete-mask-access-conversion=compile-on-910-95=False force-simt-template=False' --triton-to-annotation '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-910-95=False' --split-input-file | FileCheck %s - -// dtype: bool -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %5 = tt.splat %arg2 : !tt.ptr -> tensor<4x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<4x!tt.ptr>, tensor<4xi32> - %7 = tt.load %6 : tensor<4x!tt.ptr> - %8 = tt.gather %3[%7] {axis = 0 : i32} : (tensor<8xi8>, tensor<4xi32>) -> tensor<4xi8> - %9 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %10 = tt.addptr %9, %4 : tensor<4x!tt.ptr>, tensor<4xi32> - tt.store %10, %8 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[RESULT:.*]] = call @triton_gather(%[[SOURCE:.*]], %[[INDICES:.*]], %[[DIMENSION:.*]]) : (tensor<8xi8>, tensor<4xi32>, i32) -> tensor<4xi8> - -// ----- - -// dtype: float16 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 
: i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %5 = tt.splat %arg2 : !tt.ptr -> tensor<4x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<4x!tt.ptr>, tensor<4xi32> - %7 = tt.load %6 : tensor<4x!tt.ptr> - %8 = tt.gather %3[%7] {axis = 0 : i32} : (tensor<8xf16>, tensor<4xi32>) -> tensor<4xf16> - %9 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %10 = tt.addptr %9, %4 : tensor<4x!tt.ptr>, tensor<4xi32> - tt.store %10, %8 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[RESULT:.*]] = call @triton_gather(%[[SOURCE:.*]], %[[INDICES:.*]], %[[DIMENSION:.*]]) : (tensor<8xf16>, tensor<4xi32>, i32) -> tensor<4xf16> - -// ----- - -// dtype: float32 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %5 = tt.splat %arg2 : !tt.ptr -> tensor<4x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<4x!tt.ptr>, tensor<4xi32> - %7 = tt.load %6 : tensor<4x!tt.ptr> - %8 = tt.gather %3[%7] {axis = 0 : i32} : (tensor<8xf32>, tensor<4xi32>) -> tensor<4xf32> - %9 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %10 = tt.addptr %9, %4 : tensor<4x!tt.ptr>, tensor<4xi32> - tt.store %10, %8 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[RESULT:.*]] = call @triton_gather(%[[SOURCE:.*]], %[[INDICES:.*]], %[[DIMENSION:.*]]) : (tensor<8xf32>, tensor<4xi32>, i32) -> tensor<4xf32> - -// ----- - -// dtype: bfloat16 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %5 = tt.splat %arg2 : !tt.ptr -> tensor<4x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<4x!tt.ptr>, tensor<4xi32> - %7 = tt.load %6 : tensor<4x!tt.ptr> - %8 = tt.gather %3[%7] {axis = 0 : i32} : (tensor<8xbf16>, tensor<4xi32>) -> tensor<4xbf16> - %9 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %10 = tt.addptr %9, %4 : tensor<4x!tt.ptr>, tensor<4xi32> - tt.store %10, %8 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[RESULT:.*]] = call @triton_gather(%[[SOURCE:.*]], %[[INDICES:.*]], %[[DIMENSION:.*]]) : (tensor<8xbf16>, tensor<4xi32>, i32) -> tensor<4xbf16> - -// ----- - -// dtype: float8_e4m3 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg0 : !tt.ptr -> 
tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %5 = tt.splat %arg2 : !tt.ptr -> tensor<4x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<4x!tt.ptr>, tensor<4xi32> - %7 = tt.load %6 : tensor<4x!tt.ptr> - %8 = tt.gather %3[%7] {axis = 0 : i32} : (tensor<8xf8E4M3FN>, tensor<4xi32>) -> tensor<4xf8E4M3FN> - %9 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %10 = tt.addptr %9, %4 : tensor<4x!tt.ptr>, tensor<4xi32> - tt.store %10, %8 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[RESULT:.*]] = call @triton_gather(%[[SOURCE:.*]], %[[INDICES:.*]], %[[DIMENSION:.*]]) : (tensor<8xf8E4M3FN>, tensor<4xi32>, i32) -> tensor<4xf8E4M3FN> - -// ----- - -// dtype: float8_e5m2 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> - %3 = tt.load %2 : tensor<8x!tt.ptr> - %4 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %5 = tt.splat %arg2 : !tt.ptr -> tensor<4x!tt.ptr> - %6 = tt.addptr %5, %4 : tensor<4x!tt.ptr>, tensor<4xi32> - %7 = tt.load %6 : tensor<4x!tt.ptr> - %8 = tt.gather %3[%7] {axis = 0 : i32} : (tensor<8xf8E5M2>, tensor<4xi32>) -> tensor<4xf8E5M2> - %9 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %10 = tt.addptr %9, %4 : tensor<4x!tt.ptr>, tensor<4xi32> - tt.store %10, %8 : tensor<4x!tt.ptr> - tt.return - } -} - +// RUN: triton-adapter-opt %s --triton-linearize '--discrete-mask-access-conversion=compile-on-910-95=False force-simt-template=False' --triton-to-annotation '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-910-95=False' --split-input-file | FileCheck %s + +// dtype: bool +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %5 = tt.splat %arg2 : !tt.ptr -> tensor<4x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + %7 = tt.load %6 : tensor<4x!tt.ptr> + %8 = tt.gather %3[%7] {axis = 0 : i32} : (tensor<8xi8>, tensor<4xi32>) -> tensor<4xi8> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %10 = tt.addptr %9, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %10, %8 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[RESULT:.*]] = call @triton_gather(%[[SOURCE:.*]], %[[INDICES:.*]], %[[DIMENSION:.*]]) : (tensor<8xi8>, tensor<4xi32>, i32) -> tensor<4xi8> + +// ----- + +// dtype: float16 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %0 = tt.make_range {end = 8 : i32, start 
= 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %5 = tt.splat %arg2 : !tt.ptr -> tensor<4x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + %7 = tt.load %6 : tensor<4x!tt.ptr> + %8 = tt.gather %3[%7] {axis = 0 : i32} : (tensor<8xf16>, tensor<4xi32>) -> tensor<4xf16> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %10 = tt.addptr %9, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %10, %8 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[RESULT:.*]] = call @triton_gather(%[[SOURCE:.*]], %[[INDICES:.*]], %[[DIMENSION:.*]]) : (tensor<8xf16>, tensor<4xi32>, i32) -> tensor<4xf16> + +// ----- + +// dtype: float32 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %5 = tt.splat %arg2 : !tt.ptr -> tensor<4x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + %7 = tt.load %6 : tensor<4x!tt.ptr> + %8 = tt.gather %3[%7] {axis = 0 : i32} : (tensor<8xf32>, tensor<4xi32>) -> tensor<4xf32> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %10 = tt.addptr %9, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %10, %8 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[RESULT:.*]] = call @triton_gather(%[[SOURCE:.*]], %[[INDICES:.*]], %[[DIMENSION:.*]]) : (tensor<8xf32>, tensor<4xi32>, i32) -> tensor<4xf32> + +// ----- + +// dtype: bfloat16 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %5 = tt.splat %arg2 : !tt.ptr -> tensor<4x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + %7 = tt.load %6 : tensor<4x!tt.ptr> + %8 = tt.gather %3[%7] {axis = 0 : i32} : (tensor<8xbf16>, tensor<4xi32>) -> tensor<4xbf16> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %10 = tt.addptr %9, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %10, %8 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[RESULT:.*]] = call @triton_gather(%[[SOURCE:.*]], %[[INDICES:.*]], %[[DIMENSION:.*]]) : (tensor<8xbf16>, tensor<4xi32>, i32) -> tensor<4xbf16> + +// ----- + +// dtype: float8_e4m3 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = tt.make_range {end = 4 : i32, start = 0 : i32} : 
tensor<4xi32> + %5 = tt.splat %arg2 : !tt.ptr -> tensor<4x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + %7 = tt.load %6 : tensor<4x!tt.ptr> + %8 = tt.gather %3[%7] {axis = 0 : i32} : (tensor<8xf8E4M3FN>, tensor<4xi32>) -> tensor<4xf8E4M3FN> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %10 = tt.addptr %9, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %10, %8 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[RESULT:.*]] = call @triton_gather(%[[SOURCE:.*]], %[[INDICES:.*]], %[[DIMENSION:.*]]) : (tensor<8xf8E4M3FN>, tensor<4xi32>, i32) -> tensor<4xf8E4M3FN> + +// ----- + +// dtype: float8_e5m2 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} , %arg2: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr>, tensor<8xi32> + %3 = tt.load %2 : tensor<8x!tt.ptr> + %4 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %5 = tt.splat %arg2 : !tt.ptr -> tensor<4x!tt.ptr> + %6 = tt.addptr %5, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + %7 = tt.load %6 : tensor<4x!tt.ptr> + %8 = tt.gather %3[%7] {axis = 0 : i32} : (tensor<8xf8E5M2>, tensor<4xi32>) -> tensor<4xf8E5M2> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %10 = tt.addptr %9, %4 : tensor<4x!tt.ptr>, tensor<4xi32> + tt.store %10, %8 : tensor<4x!tt.ptr> + tt.return + } +} + // CHECK: %[[RESULT:.*]] = call @triton_gather(%[[SOURCE:.*]], %[[INDICES:.*]], %[[DIMENSION:.*]]) : (tensor<8xf8E5M2>, tensor<4xi32>, i32) -> tensor<4xf8E5M2> diff --git a/third_party/ascend/test/Conversion/TritonOp/inline_asm_elementwise.mlir b/third_party/ascend/test/Conversion/TritonOp/inline_asm_elementwise.mlir index 462aaf3c3..17c9f70c7 100644 --- a/third_party/ascend/test/Conversion/TritonOp/inline_asm_elementwise.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/inline_asm_elementwise.mlir @@ -1,179 +1,179 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-910-95=False' %s | FileCheck %s - -module { - tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { - %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %3 = tt.load %2 : tensor<2x!tt.ptr> - %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %6 = tt.load %5 : tensor<2x!tt.ptr> - // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (i32, i32) -> i32 - %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xi32>, tensor<2xi32> -> tensor<2xi32> - %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - tt.store %9, %7 : tensor<2x!tt.ptr> - tt.return - 
} -} - -// ----- - -module { - tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { - %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %3 = tt.load %2 : tensor<2x!tt.ptr> - %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %6 = tt.load %5 : tensor<2x!tt.ptr> - // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (i64, i64) -> i64 - %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xi64>, tensor<2xi64> -> tensor<2xi64> - %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - tt.store %9, %7 : tensor<2x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { - %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %3 = tt.load %2 : tensor<2x!tt.ptr> - %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %6 = tt.load %5 : tensor<2x!tt.ptr> - // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (i16, i16) -> i16 - %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xi16>, tensor<2xi16> -> tensor<2xi16> - %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - tt.store %9, %7 : tensor<2x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { - %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %3 = tt.load %2 : tensor<2x!tt.ptr> - %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %6 = tt.load %5 : tensor<2x!tt.ptr> - // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (i8, i8) -> i8 - %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xi8>, tensor<2xi8> -> tensor<2xi8> - %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - tt.store %9, %7 : tensor<2x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { - %0 = 
tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %3 = tt.load %2 : tensor<2x!tt.ptr> - %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %6 = tt.load %5 : tensor<2x!tt.ptr> - // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (f16, f16) -> f16 - %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xf16>, tensor<2xf16> -> tensor<2xf16> - %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - tt.store %9, %7 : tensor<2x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { - %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %3 = tt.load %2 : tensor<2x!tt.ptr> - %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %6 = tt.load %5 : tensor<2x!tt.ptr> - // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (f32, f32) -> f32 - %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xf32>, tensor<2xf32> -> tensor<2xf32> - %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - tt.store %9, %7 : tensor<2x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { - %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %3 = tt.load %2 : tensor<2x!tt.ptr> - %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %6 = tt.load %5 : tensor<2x!tt.ptr> - // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (bf16, bf16) -> bf16 - %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xbf16>, tensor<2xbf16> -> tensor<2xbf16> - %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - tt.store %9, %7 : tensor<2x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { - %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %3 = tt.load %2 : tensor<2x!tt.ptr> - %4 = tt.splat %arg1 : !tt.ptr -> 
tensor<2x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %6 = tt.load %5 : tensor<2x!tt.ptr> - // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (f8E4M3FN, f8E4M3FN) -> f8E4M3FN - %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xf8E4M3FN>, tensor<2xf8E4M3FN> -> tensor<2xf8E4M3FN> - %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - tt.store %9, %7 : tensor<2x!tt.ptr> - tt.return - } -} - -// ----- - -module { - tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { - %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> - %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> - %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %3 = tt.load %2 : tensor<2x!tt.ptr> - %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> - %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - %6 = tt.load %5 : tensor<2x!tt.ptr> - // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (f8E5M2, f8E5M2) -> f8E5M2 - %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xf8E5M2>, tensor<2xf8E5M2> -> tensor<2xf8E5M2> - %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> - %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> - tt.store %9, %7 : tensor<2x!tt.ptr> - tt.return - } +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-910-95=False' %s | FileCheck %s + +module { + tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %3 = tt.load %2 : tensor<2x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %6 = tt.load %5 : tensor<2x!tt.ptr> + // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (i32, i32) -> i32 + %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xi32>, tensor<2xi32> -> tensor<2xi32> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %9, %7 : tensor<2x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.splat %arg0 : 
!tt.ptr -> tensor<2x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %3 = tt.load %2 : tensor<2x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %6 = tt.load %5 : tensor<2x!tt.ptr> + // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (i64, i64) -> i64 + %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xi64>, tensor<2xi64> -> tensor<2xi64> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %9, %7 : tensor<2x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %3 = tt.load %2 : tensor<2x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %6 = tt.load %5 : tensor<2x!tt.ptr> + // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (i16, i16) -> i16 + %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xi16>, tensor<2xi16> -> tensor<2xi16> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %9, %7 : tensor<2x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %3 = tt.load %2 : tensor<2x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %6 = tt.load %5 : tensor<2x!tt.ptr> + // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (i8, i8) -> i8 + %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xi8>, tensor<2xi8> -> tensor<2xi8> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %9, %7 : tensor<2x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %3 = tt.load %2 : tensor<2x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %6 = tt.load %5 : 
tensor<2x!tt.ptr> + // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (f16, f16) -> f16 + %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xf16>, tensor<2xf16> -> tensor<2xf16> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %9, %7 : tensor<2x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %3 = tt.load %2 : tensor<2x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %6 = tt.load %5 : tensor<2x!tt.ptr> + // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (f32, f32) -> f32 + %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xf32>, tensor<2xf32> -> tensor<2xf32> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %9, %7 : tensor<2x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %3 = tt.load %2 : tensor<2x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %6 = tt.load %5 : tensor<2x!tt.ptr> + // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (bf16, bf16) -> bf16 + %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xbf16>, tensor<2xbf16> -> tensor<2xbf16> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %9, %7 : tensor<2x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %3 = tt.load %2 : tensor<2x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %6 = tt.load %5 : tensor<2x!tt.ptr> + // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (f8E4M3FN, f8E4M3FN) -> f8E4M3FN + %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " 
{constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xf8E4M3FN>, tensor<2xf8E4M3FN> -> tensor<2xf8E4M3FN> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %9, %7 : tensor<2x!tt.ptr> + tt.return + } +} + +// ----- + +module { + tt.func public @triton_asm_add(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %0 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<2x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %3 = tt.load %2 : tensor<2x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<2x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + %6 = tt.load %5 : tensor<2x!tt.ptr> + // CHECK: %[[ASM_RESULT1:.*]] = llvm.inline_asm asm_dialect = att "\0A ADD.s64 $0, $1, $2\0A ", "=l,l,l" %{{.*}}, %{{.*}} : (f8E5M2, f8E5M2) -> f8E5M2 + %7 = tt.elementwise_inline_asm "\0A ADD.s64 $0, $1, $2\0A " {constraints = "=l,l,l", packed_element = 1 : i32, pure = true} %3, %6 : tensor<2xf8E5M2>, tensor<2xf8E5M2> -> tensor<2xf8E5M2> + %8 = tt.splat %arg2 : !tt.ptr -> tensor<2x!tt.ptr> + %9 = tt.addptr %8, %0 : tensor<2x!tt.ptr>, tensor<2xi32> + tt.store %9, %7 : tensor<2x!tt.ptr> + tt.return + } } diff --git a/third_party/ascend/test/Conversion/TritonOp/load_store_make_blk_ptr_advance.mlir b/third_party/ascend/test/Conversion/TritonOp/load_store_make_blk_ptr_advance.mlir index f86ff5c35..f2ab28353 100644 --- a/third_party/ascend/test/Conversion/TritonOp/load_store_make_blk_ptr_advance.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/load_store_make_blk_ptr_advance.mlir @@ -1,266 +1,266 @@ -// RUN: triton-adapter-opt %s --triton-linearize '--discrete-mask-access-conversion=compile-on-910-95=False force-simt-template=False' --triton-to-annotation '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-910-95=False' --split-input-file | FileCheck %s - -// dtype: uint8 & int8 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { - %cst = arith.constant dense<5> : tensor<4xi64> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> - %3 = arith.addi %2, %cst : tensor<4xi64> - %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> - %5 = tt.load %4 : tensor<4x!tt.ptr> - %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> - tt.store %7, %5 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xi8, strided<[1], offset: 5>> -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi8> -// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xi8, strided<[1], offset: 5>> to memref<4xi8> -// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xi8> -// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to 
memref<4xi8, strided<[1]>> -// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xi8>, memref<4xi8, strided<[1]>>) -> () - -// ----- - -// dtype: uint16 & int16 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<5> : tensor<4xi64> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> - %3 = arith.addi %2, %cst : tensor<4xi64> - %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> - %5 = tt.load %4 : tensor<4x!tt.ptr> - %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> - tt.store %7, %5 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xi16, strided<[1], offset: 5>> -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi16> -// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xi16, strided<[1], offset: 5>> to memref<4xi16> -// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xi16> -// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xi16, strided<[1]>> -// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xi16>, memref<4xi16, strided<[1]>>) -> () - -// ----- - -// dtype: uint32 & int32 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<5> : tensor<4xi64> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> - %3 = arith.addi %2, %cst : tensor<4xi64> - %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> - %5 = tt.load %4 : tensor<4x!tt.ptr> - %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> - tt.store %7, %5 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xi32, strided<[1], offset: 5>> -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi32> -// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xi32, strided<[1], offset: 5>> to memref<4xi32> -// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xi32> -// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xi32, strided<[1]>> -// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xi32>, memref<4xi32, strided<[1]>>) -> () - - -// ----- - -// dtype: uint64 & int64 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<5> : tensor<4xi64> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> - %3 = arith.addi %2, %cst : 
tensor<4xi64> - %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> - %5 = tt.load %4 : tensor<4x!tt.ptr> - %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> - tt.store %7, %5 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xi64, strided<[1], offset: 5>> -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi64> -// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xi64, strided<[1], offset: 5>> to memref<4xi64> -// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xi64> -// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xi64, strided<[1]>> -// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xi64>, memref<4xi64, strided<[1]>>) -> () - -// ----- - -// dtype: bool -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<5> : tensor<4xi64> - %0 = tt.bitcast %arg0 : !tt.ptr -> !tt.ptr - %1 = tt.splat %0 : !tt.ptr -> tensor<4x!tt.ptr> - %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %3 = arith.extsi %2 : tensor<4xi32> to tensor<4xi64> - %4 = arith.addi %3, %cst : tensor<4xi64> - %5 = tt.addptr %1, %4 : tensor<4x!tt.ptr>, tensor<4xi64> - %6 = tt.load %5 : tensor<4x!tt.ptr> - %7 = tt.bitcast %arg1 : !tt.ptr -> !tt.ptr - %8 = tt.splat %7 : !tt.ptr -> tensor<4x!tt.ptr> - %9 = tt.addptr %8, %3 : tensor<4x!tt.ptr>, tensor<4xi64> - tt.store %9, %6 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xi8, strided<[1], offset: 5>> -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi8> -// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xi8, strided<[1], offset: 5>> to memref<4xi8> -// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xi8> -// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xi8, strided<[1]>> -// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xi8>, memref<4xi8, strided<[1]>>) -> () - -// ----- - -// dtype: float16 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<5> : tensor<4xi64> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> - %3 = arith.addi %2, %cst : tensor<4xi64> - %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> - %5 = tt.load %4 : tensor<4x!tt.ptr> - %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> - tt.store %7, %5 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xf16, strided<[1], offset: 5>> -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xf16> -// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xf16, 
strided<[1], offset: 5>> to memref<4xf16> -// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xf16> -// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xf16, strided<[1]>> -// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xf16>, memref<4xf16, strided<[1]>>) -> () - - -// ----- - -// dtype: float32 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<5> : tensor<4xi64> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> - %3 = arith.addi %2, %cst : tensor<4xi64> - %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> - %5 = tt.load %4 : tensor<4x!tt.ptr> - %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> - tt.store %7, %5 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xf32, strided<[1], offset: 5>> -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xf32> -// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xf32, strided<[1], offset: 5>> to memref<4xf32> -// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xf32> -// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xf32, strided<[1]>> -// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xf32>, memref<4xf32, strided<[1]>>) -> () - - -// ----- - -// dtype: bfloat16 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<5> : tensor<4xi64> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> - %3 = arith.addi %2, %cst : tensor<4xi64> - %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> - %5 = tt.load %4 : tensor<4x!tt.ptr> - %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> - tt.store %7, %5 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xbf16, strided<[1], offset: 5>> -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xbf16> -// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xbf16, strided<[1], offset: 5>> to memref<4xbf16> -// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xbf16> -// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xbf16, strided<[1]>> -// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xbf16>, memref<4xbf16, strided<[1]>>) -> () - - -// ----- - -// dtype: float8_e5m2 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = 
false} { - %cst = arith.constant dense<5> : tensor<4xi64> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> - %3 = arith.addi %2, %cst : tensor<4xi64> - %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> - %5 = tt.load %4 : tensor<4x!tt.ptr> - %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> - tt.store %7, %5 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xf8E5M2, strided<[1], offset: 5>> -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xf8E5M2> -// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xf8E5M2, strided<[1], offset: 5>> to memref<4xf8E5M2> -// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xf8E5M2> -// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xf8E5M2, strided<[1]>> -// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xf8E5M2>, memref<4xf8E5M2, strided<[1]>>) -> () - - -// ----- - -// dtype: float8_e4m3 -module { - tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { - %cst = arith.constant dense<5> : tensor<4xi64> - %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> - %3 = arith.addi %2, %cst : tensor<4xi64> - %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> - %5 = tt.load %4 : tensor<4x!tt.ptr> - %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> - %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> - tt.store %7, %5 : tensor<4x!tt.ptr> - tt.return - } -} - -// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xf8E4M3FN, strided<[1], offset: 5>> -// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xf8E4M3FN> -// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xf8E4M3FN, strided<[1], offset: 5>> to memref<4xf8E4M3FN> -// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xf8E4M3FN> -// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xf8E4M3FN, strided<[1]>> -// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xf8E4M3FN>, memref<4xf8E4M3FN, strided<[1]>>) -> () +// RUN: triton-adapter-opt %s --triton-linearize '--discrete-mask-access-conversion=compile-on-910-95=False force-simt-template=False' --triton-to-annotation '--triton-to-unstructure=compile-on-910-95=False force-simt-template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-910-95=False' --split-input-file | FileCheck %s + +// dtype: uint8 & int8 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant dense<5> : tensor<4xi64> + %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %1 = 
tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> + %3 = arith.addi %2, %cst : tensor<4xi64> + %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> + %5 = tt.load %4 : tensor<4x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> + tt.store %7, %5 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xi8, strided<[1], offset: 5>> +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi8> +// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xi8, strided<[1], offset: 5>> to memref<4xi8> +// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xi8> +// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xi8, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xi8>, memref<4xi8, strided<[1]>>) -> () + +// ----- + +// dtype: uint16 & int16 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %cst = arith.constant dense<5> : tensor<4xi64> + %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> + %3 = arith.addi %2, %cst : tensor<4xi64> + %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> + %5 = tt.load %4 : tensor<4x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> + tt.store %7, %5 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xi16, strided<[1], offset: 5>> +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi16> +// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xi16, strided<[1], offset: 5>> to memref<4xi16> +// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xi16> +// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xi16, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xi16>, memref<4xi16, strided<[1]>>) -> () + +// ----- + +// dtype: uint32 & int32 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %cst = arith.constant dense<5> : tensor<4xi64> + %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> + %3 = arith.addi %2, %cst : tensor<4xi64> + %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> + %5 = tt.load %4 : tensor<4x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> + tt.store %7, %5 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xi32, strided<[1], offset: 5>> +// CHECK: %[[ALLOC:.*]] = memref.alloc() : 
memref<4xi32> +// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xi32, strided<[1], offset: 5>> to memref<4xi32> +// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xi32> +// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xi32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xi32>, memref<4xi32, strided<[1]>>) -> () + + +// ----- + +// dtype: uint64 & int64 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %cst = arith.constant dense<5> : tensor<4xi64> + %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> + %3 = arith.addi %2, %cst : tensor<4xi64> + %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> + %5 = tt.load %4 : tensor<4x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> + tt.store %7, %5 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xi64, strided<[1], offset: 5>> +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi64> +// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xi64, strided<[1], offset: 5>> to memref<4xi64> +// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xi64> +// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xi64, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xi64>, memref<4xi64, strided<[1]>>) -> () + +// ----- + +// dtype: bool +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %cst = arith.constant dense<5> : tensor<4xi64> + %0 = tt.bitcast %arg0 : !tt.ptr -> !tt.ptr + %1 = tt.splat %0 : !tt.ptr -> tensor<4x!tt.ptr> + %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %3 = arith.extsi %2 : tensor<4xi32> to tensor<4xi64> + %4 = arith.addi %3, %cst : tensor<4xi64> + %5 = tt.addptr %1, %4 : tensor<4x!tt.ptr>, tensor<4xi64> + %6 = tt.load %5 : tensor<4x!tt.ptr> + %7 = tt.bitcast %arg1 : !tt.ptr -> !tt.ptr + %8 = tt.splat %7 : !tt.ptr -> tensor<4x!tt.ptr> + %9 = tt.addptr %8, %3 : tensor<4x!tt.ptr>, tensor<4xi64> + tt.store %9, %6 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xi8, strided<[1], offset: 5>> +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi8> +// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xi8, strided<[1], offset: 5>> to memref<4xi8> +// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xi8> +// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xi8, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xi8>, memref<4xi8, strided<[1]>>) -> () + +// ----- + +// dtype: float16 +module { + 
tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %cst = arith.constant dense<5> : tensor<4xi64> + %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> + %3 = arith.addi %2, %cst : tensor<4xi64> + %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> + %5 = tt.load %4 : tensor<4x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> + tt.store %7, %5 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xf16, strided<[1], offset: 5>> +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xf16> +// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xf16, strided<[1], offset: 5>> to memref<4xf16> +// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xf16> +// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xf16, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xf16>, memref<4xf16, strided<[1]>>) -> () + + +// ----- + +// dtype: float32 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %cst = arith.constant dense<5> : tensor<4xi64> + %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> + %3 = arith.addi %2, %cst : tensor<4xi64> + %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> + %5 = tt.load %4 : tensor<4x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> + tt.store %7, %5 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xf32, strided<[1], offset: 5>> +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xf32> +// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xf32, strided<[1], offset: 5>> to memref<4xf32> +// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xf32> +// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xf32, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xf32>, memref<4xf32, strided<[1]>>) -> () + + +// ----- + +// dtype: bfloat16 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %cst = arith.constant dense<5> : tensor<4xi64> + %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> + %3 = arith.addi %2, %cst : tensor<4xi64> + %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> + %5 = tt.load %4 : tensor<4x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> + 
tt.store %7, %5 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xbf16, strided<[1], offset: 5>> +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xbf16> +// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xbf16, strided<[1], offset: 5>> to memref<4xbf16> +// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xbf16> +// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xbf16, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xbf16>, memref<4xbf16, strided<[1]>>) -> () + + +// ----- + +// dtype: float8_e5m2 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %cst = arith.constant dense<5> : tensor<4xi64> + %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> + %3 = arith.addi %2, %cst : tensor<4xi64> + %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> + %5 = tt.load %4 : tensor<4x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> + tt.store %7, %5 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xf8E5M2, strided<[1], offset: 5>> +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xf8E5M2> +// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xf8E5M2, strided<[1], offset: 5>> to memref<4xf8E5M2> +// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xf8E5M2> +// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: [4], strides: [1] : memref to memref<4xf8E5M2, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xf8E5M2>, memref<4xf8E5M2, strided<[1]>>) -> () + + +// ----- + +// dtype: float8_e4m3 +module { + tt.func public @triton_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32} , %arg1: !tt.ptr {tt.divisibility = 16 : i32} ) attributes {noinline = false} { + %cst = arith.constant dense<5> : tensor<4xi64> + %0 = tt.splat %arg0 : !tt.ptr -> tensor<4x!tt.ptr> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = arith.extsi %1 : tensor<4xi32> to tensor<4xi64> + %3 = arith.addi %2, %cst : tensor<4xi64> + %4 = tt.addptr %0, %3 : tensor<4x!tt.ptr>, tensor<4xi64> + %5 = tt.load %4 : tensor<4x!tt.ptr> + %6 = tt.splat %arg1 : !tt.ptr -> tensor<4x!tt.ptr> + %7 = tt.addptr %6, %2 : tensor<4x!tt.ptr>, tensor<4xi64> + tt.store %7, %5 : tensor<4x!tt.ptr> + tt.return + } +} + +// CHECK: %[[REINT_CAST0:.*]] = memref.reinterpret_cast %[[ARG0:.*]] to offset: [5], sizes: [4], strides: [1] : memref to memref<4xf8E4M3FN, strided<[1], offset: 5>> +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xf8E4M3FN> +// CHECK: memref.copy %[[REINT_CAST0]], %[[ALLOC]] : memref<4xf8E4M3FN, strided<[1], offset: 5>> to memref<4xf8E4M3FN> +// CHECK: %[[VAL0:.*]] = bufferization.to_tensor %[[ALLOC]] restrict writable : memref<4xf8E4M3FN> +// CHECK: %[[REINT_CAST1:.*]] = memref.reinterpret_cast %[[ARG1:.*]] to offset: [0], sizes: 
[4], strides: [1] : memref to memref<4xf8E4M3FN, strided<[1]>> +// CHECK: bufferization.materialize_in_destination %[[VAL0]] in writable %[[REINT_CAST1]] : (tensor<4xf8E4M3FN>, memref<4xf8E4M3FN, strided<[1]>>) -> () diff --git a/third_party/ascend/test/Conversion/TritonOp/max_uint.mlir b/third_party/ascend/test/Conversion/TritonOp/max_uint.mlir index 78d77dac5..773f7c84a 100644 --- a/third_party/ascend/test/Conversion/TritonOp/max_uint.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/max_uint.mlir @@ -1,78 +1,78 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-a5=False force_simt_template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-a5=False' --split-input-file %s | FileCheck %s - -module { - tt.func public @triton_max_1d8(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":22:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":22:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":22:0)) attributes {noinline = false} { - %true = arith.constant true loc(#loc1) - %0 = tt.get_program_id x : i32 loc(#loc2) - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc3) - %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc4) - %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc4) - %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc5) - %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc5) - %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc6) - tt.assert %true, "Expecting input to be integer type" : i1 loc(#loc14) - %7 = arith.extui %6 : tensor<16xi8> to tensor<16xi32> loc(#loc15) - %8 = "tt.reduce"(%7) <{axis = 0 : i32}> ({ - ^bb0(%arg3: i32 loc(callsite(#loc1 at #loc8)), %arg4: i32 loc(callsite(#loc1 at #loc8))): - %10 = arith.maxsi %arg3, %arg4 : i32 loc(#loc19) - tt.reduce.return %10 : i32 loc(#loc16) - }) : (tensor<16xi32>) -> i32 loc(#loc16) - %9 = arith.trunci %8 : i32 to i8 loc(#loc12) - tt.store %arg1, %9 : !tt.ptr loc(#loc12) - tt.return loc(#loc13) - } loc(#loc) -} loc(#loc) - - -// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.maxsi %[[VAL_0]], %[[VAL_1]] : i32 -// ----------- - - -module { - tt.func public @triton_max_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":45:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":45:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":45:0)) attributes {noinline = false} { - %true = arith.constant true loc(#loc1) - %0 = tt.get_program_id x : i32 loc(#loc2) - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc3) - %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc4) - %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc4) - %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc5) - %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc5) - %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc6) - tt.assert %true, "Expecting input to be integer type" : i1 loc(#loc14) - %7 = arith.extui %6 : tensor<16xi16> to tensor<16xi32> loc(#loc15) - %8 = "tt.reduce"(%7) <{axis = 0 : i32}> ({ - ^bb0(%arg3: i32 loc(callsite(#loc1 at #loc8)), %arg4: i32 loc(callsite(#loc1 at #loc8))): - %10 = arith.maxsi %arg3, %arg4 : i32 loc(#loc19) - tt.reduce.return %10 : i32 
loc(#loc16) - }) : (tensor<16xi32>) -> i32 loc(#loc16) - %9 = arith.trunci %8 : i32 to i16 loc(#loc12) - tt.store %arg1, %9 : !tt.ptr loc(#loc12) - tt.return loc(#loc13) - } loc(#loc) -} loc(#loc) - - -// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.maxsi %[[VAL_0]], %[[VAL_1]] : i32 -// ----------- - -module { - tt.func public @triton_max_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":45:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":45:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":45:0)) attributes {noinline = false} { - %0 = tt.get_program_id x : i32 loc(#loc1) - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) - %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) - %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) - %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) - %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) - %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) - %7 = "tt.reduce"(%6) <{axis = 0 : i32}> ({ - ^bb0(%arg3: i32 loc(callsite(#loc8 at #loc7)), %arg4: i32 loc(callsite(#loc8 at #loc7))): - %8 = arith.maxui %arg3, %arg4 : i32 loc(#loc15) - tt.reduce.return %8 : i32 loc(#loc12) - }) : (tensor<16xi32>) -> i32 loc(#loc12) - tt.store %arg1, %7 : !tt.ptr loc(#loc10) - tt.return loc(#loc11) - } loc(#loc) -} loc(#loc) - -// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.maxui %[[VAL_0]], %[[VAL_1]] : i32 -// ----------- +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-a5=False force_simt_template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-a5=False' --split-input-file %s | FileCheck %s + +module { + tt.func public @triton_max_1d8(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":22:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":22:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":22:0)) attributes {noinline = false} { + %true = arith.constant true loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc3) + %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc4) + %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc4) + %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc5) + %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc5) + %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc6) + tt.assert %true, "Expecting input to be integer type" : i1 loc(#loc14) + %7 = arith.extui %6 : tensor<16xi8> to tensor<16xi32> loc(#loc15) + %8 = "tt.reduce"(%7) <{axis = 0 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc1 at #loc8)), %arg4: i32 loc(callsite(#loc1 at #loc8))): + %10 = arith.maxsi %arg3, %arg4 : i32 loc(#loc19) + tt.reduce.return %10 : i32 loc(#loc16) + }) : (tensor<16xi32>) -> i32 loc(#loc16) + %9 = arith.trunci %8 : i32 to i8 loc(#loc12) + tt.store %arg1, %9 : !tt.ptr loc(#loc12) + tt.return loc(#loc13) + } loc(#loc) +} loc(#loc) + + +// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.maxsi %[[VAL_0]], %[[VAL_1]] : i32 +// ----------- + + +module { + tt.func public @triton_max_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":45:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} 
loc("/home/l30058175/wxue/max_uint.py":45:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":45:0)) attributes {noinline = false} { + %true = arith.constant true loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc3) + %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc4) + %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc4) + %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc5) + %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc5) + %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc6) + tt.assert %true, "Expecting input to be integer type" : i1 loc(#loc14) + %7 = arith.extui %6 : tensor<16xi16> to tensor<16xi32> loc(#loc15) + %8 = "tt.reduce"(%7) <{axis = 0 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc1 at #loc8)), %arg4: i32 loc(callsite(#loc1 at #loc8))): + %10 = arith.maxsi %arg3, %arg4 : i32 loc(#loc19) + tt.reduce.return %10 : i32 loc(#loc16) + }) : (tensor<16xi32>) -> i32 loc(#loc16) + %9 = arith.trunci %8 : i32 to i16 loc(#loc12) + tt.store %arg1, %9 : !tt.ptr loc(#loc12) + tt.return loc(#loc13) + } loc(#loc) +} loc(#loc) + + +// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.maxsi %[[VAL_0]], %[[VAL_1]] : i32 +// ----------- + +module { + tt.func public @triton_max_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":45:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":45:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/max_uint.py":45:0)) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 loc(#loc1) + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) + %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) + %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) + %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) + %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) + %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) + %7 = "tt.reduce"(%6) <{axis = 0 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc8 at #loc7)), %arg4: i32 loc(callsite(#loc8 at #loc7))): + %8 = arith.maxui %arg3, %arg4 : i32 loc(#loc15) + tt.reduce.return %8 : i32 loc(#loc12) + }) : (tensor<16xi32>) -> i32 loc(#loc12) + tt.store %arg1, %7 : !tt.ptr loc(#loc10) + tt.return loc(#loc11) + } loc(#loc) +} loc(#loc) + +// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.maxui %[[VAL_0]], %[[VAL_1]] : i32 +// ----------- diff --git a/third_party/ascend/test/Conversion/TritonOp/min_uint.mlir b/third_party/ascend/test/Conversion/TritonOp/min_uint.mlir index 813313604..1c96232e9 100644 --- a/third_party/ascend/test/Conversion/TritonOp/min_uint.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/min_uint.mlir @@ -1,76 +1,76 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-a5=False force_simt_template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-a5=False' --split-input-file %s | FileCheck %s - -module { - tt.func public @triton_min_1d8(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":22:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":22:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":22:0)) attributes {noinline = false} { - 
%true = arith.constant true loc(#loc1) - %0 = tt.get_program_id x : i32 loc(#loc2) - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc3) - %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc4) - %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc4) - %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc5) - %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc5) - %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc6) - tt.assert %true, "Expecting input to be integer type" : i1 loc(#loc14) - %7 = arith.extui %6 : tensor<16xi8> to tensor<16xi32> loc(#loc15) - %8 = "tt.reduce"(%7) <{axis = 0 : i32}> ({ - ^bb0(%arg3: i32 loc(callsite(#loc1 at #loc8)), %arg4: i32 loc(callsite(#loc1 at #loc8))): - %10 = arith.minsi %arg3, %arg4 : i32 loc(#loc19) - tt.reduce.return %10 : i32 loc(#loc16) - }) : (tensor<16xi32>) -> i32 loc(#loc16) - %9 = arith.trunci %8 : i32 to i8 loc(#loc12) - tt.store %arg1, %9 : !tt.ptr loc(#loc12) - tt.return loc(#loc13) - } loc(#loc) -} loc(#loc) - -// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.minsi %[[VAL_0]], %[[VAL_1]] : i32 -// ----------- - - -module { - tt.func public @triton_min_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":45:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":45:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":45:0)) attributes {noinline = false} { - %true = arith.constant true loc(#loc1) - %0 = tt.get_program_id x : i32 loc(#loc2) - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc3) - %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc4) - %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc4) - %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc5) - %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc5) - %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc6) - tt.assert %true, "Expecting input to be integer type" : i1 loc(#loc14) - %7 = arith.extui %6 : tensor<16xi16> to tensor<16xi32> loc(#loc15) - %8 = "tt.reduce"(%7) <{axis = 0 : i32}> ({ - ^bb0(%arg3: i32 loc(callsite(#loc1 at #loc8)), %arg4: i32 loc(callsite(#loc1 at #loc8))): - %10 = arith.minsi %arg3, %arg4 : i32 loc(#loc19) - tt.reduce.return %10 : i32 loc(#loc16) - }) : (tensor<16xi32>) -> i32 loc(#loc16) - %9 = arith.trunci %8 : i32 to i16 loc(#loc12) - tt.store %arg1, %9 : !tt.ptr loc(#loc12) - tt.return loc(#loc13) - } loc(#loc) -} loc(#loc) - -// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.minsi %[[VAL_0]], %[[VAL_1]] : i32 -// ----------- - -module { - tt.func public @triton_min_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":45:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":45:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":45:0)) attributes {noinline = false} { - %0 = tt.get_program_id x : i32 loc(#loc1) - %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) - %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) - %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) - %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) - %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) - %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) - %7 = "tt.reduce"(%6) <{axis = 0 : i32}> ({ - ^bb0(%arg3: i32 loc(callsite(#loc8 at #loc7)), %arg4: i32 loc(callsite(#loc8 at #loc7))): - %8 = arith.minui %arg3, %arg4 : i32 loc(#loc15) - tt.reduce.return %8 : i32 
loc(#loc12) - }) : (tensor<16xi32>) -> i32 loc(#loc12) - tt.store %arg1, %7 : !tt.ptr loc(#loc10) - tt.return loc(#loc11) - } loc(#loc) -} loc(#loc) - -// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.minui %[[VAL_0]], %[[VAL_1]] : i32 +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation '--triton-to-unstructure=compile-on-a5=False force_simt_template=False' --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False compile-on-a5=False' --split-input-file %s | FileCheck %s + +module { + tt.func public @triton_min_1d8(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":22:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":22:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":22:0)) attributes {noinline = false} { + %true = arith.constant true loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc3) + %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc4) + %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc4) + %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc5) + %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc5) + %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc6) + tt.assert %true, "Expecting input to be integer type" : i1 loc(#loc14) + %7 = arith.extui %6 : tensor<16xi8> to tensor<16xi32> loc(#loc15) + %8 = "tt.reduce"(%7) <{axis = 0 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc1 at #loc8)), %arg4: i32 loc(callsite(#loc1 at #loc8))): + %10 = arith.minsi %arg3, %arg4 : i32 loc(#loc19) + tt.reduce.return %10 : i32 loc(#loc16) + }) : (tensor<16xi32>) -> i32 loc(#loc16) + %9 = arith.trunci %8 : i32 to i8 loc(#loc12) + tt.store %arg1, %9 : !tt.ptr loc(#loc12) + tt.return loc(#loc13) + } loc(#loc) +} loc(#loc) + +// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.minsi %[[VAL_0]], %[[VAL_1]] : i32 +// ----------- + + +module { + tt.func public @triton_min_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":45:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":45:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":45:0)) attributes {noinline = false} { + %true = arith.constant true loc(#loc1) + %0 = tt.get_program_id x : i32 loc(#loc2) + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc3) + %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc4) + %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc4) + %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc5) + %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc5) + %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc6) + tt.assert %true, "Expecting input to be integer type" : i1 loc(#loc14) + %7 = arith.extui %6 : tensor<16xi16> to tensor<16xi32> loc(#loc15) + %8 = "tt.reduce"(%7) <{axis = 0 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc1 at #loc8)), %arg4: i32 loc(callsite(#loc1 at #loc8))): + %10 = arith.minsi %arg3, %arg4 : i32 loc(#loc19) + tt.reduce.return %10 : i32 loc(#loc16) + }) : (tensor<16xi32>) -> i32 loc(#loc16) + %9 = arith.trunci %8 : i32 to i16 loc(#loc12) + tt.store %arg1, %9 : !tt.ptr loc(#loc12) + tt.return loc(#loc13) + } loc(#loc) +} loc(#loc) + +// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.minsi %[[VAL_0]], %[[VAL_1]] : i32 +// ----------- + +module { + 
tt.func public @triton_min_1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":45:0), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":45:0), %arg2: i32 {tt.divisibility = 16 : i32} loc("/home/l30058175/wxue/min_uint.py":45:0)) attributes {noinline = false} { + %0 = tt.get_program_id x : i32 loc(#loc1) + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc2) + %2 = tt.splat %0 : i32 -> tensor<16xi32> loc(#loc3) + %3 = arith.addi %2, %1 : tensor<16xi32> loc(#loc3) + %4 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr> loc(#loc4) + %5 = tt.addptr %4, %3 : tensor<16x!tt.ptr>, tensor<16xi32> loc(#loc4) + %6 = tt.load %5 : tensor<16x!tt.ptr> loc(#loc5) + %7 = "tt.reduce"(%6) <{axis = 0 : i32}> ({ + ^bb0(%arg3: i32 loc(callsite(#loc8 at #loc7)), %arg4: i32 loc(callsite(#loc8 at #loc7))): + %8 = arith.minui %arg3, %arg4 : i32 loc(#loc15) + tt.reduce.return %8 : i32 loc(#loc12) + }) : (tensor<16xi32>) -> i32 loc(#loc12) + tt.store %arg1, %7 : !tt.ptr loc(#loc10) + tt.return loc(#loc11) + } loc(#loc) +} loc(#loc) + +// CHECK: %[[VAL_2:[A-Za-z0-9_]+]] = arith.minui %[[VAL_0]], %[[VAL_1]] : i32 // ----------- diff --git a/third_party/ascend/test/Conversion/TritonOp/reduce_sum.mlir b/third_party/ascend/test/Conversion/TritonOp/reduce_sum.mlir index d943de800..7f29ccf8e 100644 --- a/third_party/ascend/test/Conversion/TritonOp/reduce_sum.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/reduce_sum.mlir @@ -1,270 +1,270 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s - -// === reduce and sum use the same case === -// === i8 u8 version === -module { - tt.func public @fn_npu_u8( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %cst = arith.constant dense<4> : tensor<1x8x1xi32> - %cst_0 = arith.constant dense<4> : tensor<8x1x1xi32> - %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = arith.muli %3, %cst_1 : tensor<8x1x1xi32> - %5 = arith.muli %4, %cst_0 : tensor<8x1x1xi32> - %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %8 = arith.muli %7, %cst : tensor<1x8x1xi32> - %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %11 = arith.addi %9, %10 : tensor<8x8x1xi32> - %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %16 = arith.addi %14, %15 : tensor<8x8x4xi32> - %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - %19 = tt.load %18 : tensor<8x8x4x!tt.ptr> - %20 = tt.reshape %19 : 
tensor<8x8x4xi8> -> tensor<256xi8> - %21 = "tt.reduce"(%20) <{axis = 0 : i32}> ({ - ^bb0(%arg2: i8, %arg3: i8): - %22 = arith.addi %arg2, %arg3 : i8 - tt.reduce.return %22 : i8 - }) : (tensor<256xi8>) -> i8 - tt.store %arg0, %21 : !tt.ptr - tt.return - } -} - -// === i16 u16 version === -module { - tt.func public @fn_npu_i16( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %cst = arith.constant dense<4> : tensor<1x8x1xi32> - %cst_0 = arith.constant dense<4> : tensor<8x1x1xi32> - %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = arith.muli %3, %cst_1 : tensor<8x1x1xi32> - %5 = arith.muli %4, %cst_0 : tensor<8x1x1xi32> - %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %8 = arith.muli %7, %cst : tensor<1x8x1xi32> - %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %11 = arith.addi %9, %10 : tensor<8x8x1xi32> - %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %16 = arith.addi %14, %15 : tensor<8x8x4xi32> - %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - %19 = tt.load %18 : tensor<8x8x4x!tt.ptr> - %20 = tt.reshape %19 : tensor<8x8x4xi16> -> tensor<256xi16> - %21 = "tt.reduce"(%20) <{axis = 0 : i32}> ({ - ^bb0(%arg2: i16, %arg3: i16): - %22 = arith.addi %arg2, %arg3 : i16 - tt.reduce.return %22 : i16 - }) : (tensor<256xi16>) -> i16 - tt.store %arg0, %21 : !tt.ptr - tt.return - } -} - -// === i32 u32 version === -module { - tt.func public @fn_npu_i32( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %cst = arith.constant dense<4> : tensor<1x8x1xi32> - %cst_0 = arith.constant dense<4> : tensor<8x1x1xi32> - %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = arith.muli %3, %cst_1 : tensor<8x1x1xi32> - %5 = arith.muli %4, %cst_0 : tensor<8x1x1xi32> - %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %8 = arith.muli %7, %cst : tensor<1x8x1xi32> - %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %11 = arith.addi %9, %10 : tensor<8x8x1xi32> - %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> 
tensor<8x8x4xi32> - %16 = arith.addi %14, %15 : tensor<8x8x4xi32> - %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - %19 = tt.load %18 : tensor<8x8x4x!tt.ptr> - %20 = tt.reshape %19 : tensor<8x8x4xi32> -> tensor<256xi32> - %21 = "tt.reduce"(%20) <{axis = 0 : i32}> ({ - ^bb0(%arg2: i32, %arg3: i32): - %22 = arith.addi %arg2, %arg3 : i32 - tt.reduce.return %22 : i32 - }) : (tensor<256xi32>) -> i32 - tt.store %arg0, %21 : !tt.ptr - tt.return - } -} - -// === i64 u64 version === -module { - tt.func public @fn_npu_i64( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %cst = arith.constant dense<4> : tensor<1x8x1xi32> - %cst_0 = arith.constant dense<4> : tensor<8x1x1xi32> - %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = arith.muli %3, %cst_1 : tensor<8x1x1xi32> - %5 = arith.muli %4, %cst_0 : tensor<8x1x1xi32> - %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %8 = arith.muli %7, %cst : tensor<1x8x1xi32> - %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %11 = arith.addi %9, %10 : tensor<8x8x1xi32> - %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %16 = arith.addi %14, %15 : tensor<8x8x4xi32> - %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - %19 = tt.load %18 : tensor<8x8x4x!tt.ptr> - %20 = tt.reshape %19 : tensor<8x8x4xi64> -> tensor<256xi64> - %21 = "tt.reduce"(%20) <{axis = 0 : i32}> ({ - ^bb0(%arg2: i64, %arg3: i64): - %22 = arith.addi %arg2, %arg3 : i64 - tt.reduce.return %22 : i64 - }) : (tensor<256xi64>) -> i64 - tt.store %arg0, %21 : !tt.ptr - tt.return - } -} - -// === f8E4M3FN version === -module { - tt.func public @fn_npu_f8E4M3FN(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<4> : tensor<1x8x1xi32> - %cst_0 = arith.constant dense<4> : tensor<8x1x1xi32> - %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = arith.muli %3, %cst_1 : tensor<8x1x1xi32> - %5 = arith.muli %4, %cst_0 : tensor<8x1x1xi32> - %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %8 = arith.muli %7, %cst : tensor<1x8x1xi32> - %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %11 = arith.addi %9, %10 : tensor<8x8x1xi32> - 
%12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %16 = arith.addi %14, %15 : tensor<8x8x4xi32> - %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - %19 = tt.load %18 : tensor<8x8x4x!tt.ptr> - %20 = tt.reshape %19 allow_reorder : tensor<8x8x4xf8E4M3FN> -> tensor<256xf8E4M3FN> - %21 = "tt.reduce"(%20) <{axis = 0 : i32}> ({ - ^bb0(%arg2: f8E4M3FN, %arg3: f8E4M3FN): - %22 = arith.addf %arg2, %arg3 : f8E4M3FN - tt.reduce.return %22 : f8E4M3FN - }) : (tensor<256xf8E4M3FN>) -> f8E4M3FN - tt.store %arg0, %21 : !tt.ptr - tt.return - } -} - -// === f8E5M2 version === -module { - tt.func public @fn_npu_f8E5M2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<4> : tensor<1x8x1xi32> - %cst_0 = arith.constant dense<4> : tensor<8x1x1xi32> - %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = arith.muli %3, %cst_1 : tensor<8x1x1xi32> - %5 = arith.muli %4, %cst_0 : tensor<8x1x1xi32> - %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %8 = arith.muli %7, %cst : tensor<1x8x1xi32> - %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %11 = arith.addi %9, %10 : tensor<8x8x1xi32> - %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %16 = arith.addi %14, %15 : tensor<8x8x4xi32> - %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - %19 = tt.load %18 : tensor<8x8x4x!tt.ptr> - %20 = tt.reshape %19 allow_reorder : tensor<8x8x4xf8E5M2> -> tensor<256xf8E5M2> - %21 = "tt.reduce"(%20) <{axis = 0 : i32}> ({ - ^bb0(%arg2: f8E5M2, %arg3: f8E5M2): - %22 = arith.addf %arg2, %arg3 : f8E5M2 - tt.reduce.return %22 : f8E5M2 - }) : (tensor<256xf8E5M2>) -> f8E5M2 - tt.store %arg0, %21 : !tt.ptr - tt.return - } -} - - -// ===== CHECKS ===== -// CHECK-DAG: arith.constant 0 : i8 -// CHECK-DAG: tensor.reshape %{{.*}} : (tensor<8x8x4xi8>, {{.*}}) -> tensor<256xi8> -// CHECK-DAG: linalg.reduce ins(%{{.*}} : tensor<256xi8>) outs(%{{.*}} : tensor) dimensions = [0] -// CHECK-DAG: arith.addi %in, %init : i8 - -// CHECK-DAG: arith.constant 0 : i16 -// CHECK-DAG: tensor.reshape %{{.*}} : (tensor<8x8x4xi16>, {{.*}}) -> tensor<256xi16> -// CHECK-DAG: linalg.reduce ins(%{{.*}} : tensor<256xi16>) outs(%{{.*}} : tensor) dimensions = [0] -// CHECK-DAG: arith.addi %in, %init : i16 - -// CHECK-DAG: arith.constant 0 : i32 -// CHECK-DAG: tensor.reshape %{{.*}} : (tensor<8x8x4xi32>, {{.*}}) -> tensor<256xi32> -// CHECK-DAG: linalg.reduce ins(%{{.*}} : tensor<256xi32>) 
outs(%{{.*}} : tensor) dimensions = [0] -// CHECK-DAG: arith.addi %in, %init : i32 - -// CHECK-DAG: arith.constant 0 : i64 -// CHECK-DAG: tensor.reshape %{{.*}} : (tensor<8x8x4xi64>, {{.*}}) -> tensor<256xi64> -// CHECK-DAG: linalg.reduce ins(%{{.*}} : tensor<256xi64>) outs(%{{.*}} : tensor) dimensions = [0] -// CHECK-DAG: arith.addi %in, %init : i64 - -// CHECK-DAG: arith.constant 0.0{{.*}} : f8E4M3FN -// CHECK-DAG: tensor.reshape %{{.*}} : (tensor<8x8x4xf8E4M3FN>, {{.*}}) -> tensor<256xf8E4M3FN> -// CHECK-DAG: linalg.reduce ins(%{{.*}} : tensor<256xf8E4M3FN>) outs(%{{.*}} : tensor) dimensions = [0] -// CHECK-DAG: arith.addf %in, %init : f8E4M3FN - -// CHECK-DAG: arith.constant 0.0{{.*}} : f8E5M2 -// CHECK-DAG: tensor.reshape %{{.*}} : (tensor<8x8x4xf8E5M2>, {{.*}}) -> tensor<256xf8E5M2> -// CHECK-DAG: linalg.reduce ins(%{{.*}} : tensor<256xf8E5M2>) outs(%{{.*}} : tensor) dimensions = [0] +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s + +// === reduce and sum use the same case === +// === i8 u8 version === +module { + tt.func public @fn_npu_u8( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %cst = arith.constant dense<4> : tensor<1x8x1xi32> + %cst_0 = arith.constant dense<4> : tensor<8x1x1xi32> + %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = arith.muli %3, %cst_1 : tensor<8x1x1xi32> + %5 = arith.muli %4, %cst_0 : tensor<8x1x1xi32> + %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %8 = arith.muli %7, %cst : tensor<1x8x1xi32> + %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %11 = arith.addi %9, %10 : tensor<8x8x1xi32> + %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %16 = arith.addi %14, %15 : tensor<8x8x4xi32> + %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + %19 = tt.load %18 : tensor<8x8x4x!tt.ptr> + %20 = tt.reshape %19 : tensor<8x8x4xi8> -> tensor<256xi8> + %21 = "tt.reduce"(%20) <{axis = 0 : i32}> ({ + ^bb0(%arg2: i8, %arg3: i8): + %22 = arith.addi %arg2, %arg3 : i8 + tt.reduce.return %22 : i8 + }) : (tensor<256xi8>) -> i8 + tt.store %arg0, %21 : !tt.ptr + tt.return + } +} + +// === i16 u16 version === +module { + tt.func public @fn_npu_i16( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %cst = arith.constant dense<4> : tensor<1x8x1xi32> + %cst_0 = arith.constant dense<4> : tensor<8x1x1xi32> + %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> + %0 = tt.make_range {end = 8 : i32, start = 0 
: i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = arith.muli %3, %cst_1 : tensor<8x1x1xi32> + %5 = arith.muli %4, %cst_0 : tensor<8x1x1xi32> + %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %8 = arith.muli %7, %cst : tensor<1x8x1xi32> + %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %11 = arith.addi %9, %10 : tensor<8x8x1xi32> + %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %16 = arith.addi %14, %15 : tensor<8x8x4xi32> + %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + %19 = tt.load %18 : tensor<8x8x4x!tt.ptr> + %20 = tt.reshape %19 : tensor<8x8x4xi16> -> tensor<256xi16> + %21 = "tt.reduce"(%20) <{axis = 0 : i32}> ({ + ^bb0(%arg2: i16, %arg3: i16): + %22 = arith.addi %arg2, %arg3 : i16 + tt.reduce.return %22 : i16 + }) : (tensor<256xi16>) -> i16 + tt.store %arg0, %21 : !tt.ptr + tt.return + } +} + +// === i32 u32 version === +module { + tt.func public @fn_npu_i32( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %cst = arith.constant dense<4> : tensor<1x8x1xi32> + %cst_0 = arith.constant dense<4> : tensor<8x1x1xi32> + %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = arith.muli %3, %cst_1 : tensor<8x1x1xi32> + %5 = arith.muli %4, %cst_0 : tensor<8x1x1xi32> + %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %8 = arith.muli %7, %cst : tensor<1x8x1xi32> + %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %11 = arith.addi %9, %10 : tensor<8x8x1xi32> + %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %16 = arith.addi %14, %15 : tensor<8x8x4xi32> + %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + %19 = tt.load %18 : tensor<8x8x4x!tt.ptr> + %20 = tt.reshape %19 : tensor<8x8x4xi32> -> tensor<256xi32> + %21 = "tt.reduce"(%20) <{axis = 0 : i32}> ({ + ^bb0(%arg2: i32, %arg3: i32): + %22 = arith.addi %arg2, %arg3 : i32 + tt.reduce.return %22 : i32 + }) : (tensor<256xi32>) -> i32 + tt.store %arg0, %21 : !tt.ptr + tt.return + } +} + +// === i64 u64 version === +module { + tt.func public @fn_npu_i64( + %arg0: !tt.ptr {tt.divisibility = 
16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %cst = arith.constant dense<4> : tensor<1x8x1xi32> + %cst_0 = arith.constant dense<4> : tensor<8x1x1xi32> + %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = arith.muli %3, %cst_1 : tensor<8x1x1xi32> + %5 = arith.muli %4, %cst_0 : tensor<8x1x1xi32> + %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %8 = arith.muli %7, %cst : tensor<1x8x1xi32> + %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %11 = arith.addi %9, %10 : tensor<8x8x1xi32> + %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %16 = arith.addi %14, %15 : tensor<8x8x4xi32> + %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + %19 = tt.load %18 : tensor<8x8x4x!tt.ptr> + %20 = tt.reshape %19 : tensor<8x8x4xi64> -> tensor<256xi64> + %21 = "tt.reduce"(%20) <{axis = 0 : i32}> ({ + ^bb0(%arg2: i64, %arg3: i64): + %22 = arith.addi %arg2, %arg3 : i64 + tt.reduce.return %22 : i64 + }) : (tensor<256xi64>) -> i64 + tt.store %arg0, %21 : !tt.ptr + tt.return + } +} + +// === f8E4M3FN version === +module { + tt.func public @fn_npu_f8E4M3FN(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<4> : tensor<1x8x1xi32> + %cst_0 = arith.constant dense<4> : tensor<8x1x1xi32> + %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = arith.muli %3, %cst_1 : tensor<8x1x1xi32> + %5 = arith.muli %4, %cst_0 : tensor<8x1x1xi32> + %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %8 = arith.muli %7, %cst : tensor<1x8x1xi32> + %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %11 = arith.addi %9, %10 : tensor<8x8x1xi32> + %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %16 = arith.addi %14, %15 : tensor<8x8x4xi32> + %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + %19 = tt.load %18 : tensor<8x8x4x!tt.ptr> + %20 = tt.reshape %19 allow_reorder : tensor<8x8x4xf8E4M3FN> -> tensor<256xf8E4M3FN> + %21 = "tt.reduce"(%20) <{axis = 0 : 
i32}> ({ + ^bb0(%arg2: f8E4M3FN, %arg3: f8E4M3FN): + %22 = arith.addf %arg2, %arg3 : f8E4M3FN + tt.reduce.return %22 : f8E4M3FN + }) : (tensor<256xf8E4M3FN>) -> f8E4M3FN + tt.store %arg0, %21 : !tt.ptr + tt.return + } +} + +// === f8E5M2 version === +module { + tt.func public @fn_npu_f8E5M2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<4> : tensor<1x8x1xi32> + %cst_0 = arith.constant dense<4> : tensor<8x1x1xi32> + %cst_1 = arith.constant dense<8> : tensor<8x1x1xi32> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = arith.muli %3, %cst_1 : tensor<8x1x1xi32> + %5 = arith.muli %4, %cst_0 : tensor<8x1x1xi32> + %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %8 = arith.muli %7, %cst : tensor<1x8x1xi32> + %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %11 = arith.addi %9, %10 : tensor<8x8x1xi32> + %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %16 = arith.addi %14, %15 : tensor<8x8x4xi32> + %17 = tt.splat %arg1 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + %19 = tt.load %18 : tensor<8x8x4x!tt.ptr> + %20 = tt.reshape %19 allow_reorder : tensor<8x8x4xf8E5M2> -> tensor<256xf8E5M2> + %21 = "tt.reduce"(%20) <{axis = 0 : i32}> ({ + ^bb0(%arg2: f8E5M2, %arg3: f8E5M2): + %22 = arith.addf %arg2, %arg3 : f8E5M2 + tt.reduce.return %22 : f8E5M2 + }) : (tensor<256xf8E5M2>) -> f8E5M2 + tt.store %arg0, %21 : !tt.ptr + tt.return + } +} + + +// ===== CHECKS ===== +// CHECK-DAG: arith.constant 0 : i8 +// CHECK-DAG: tensor.reshape %{{.*}} : (tensor<8x8x4xi8>, {{.*}}) -> tensor<256xi8> +// CHECK-DAG: linalg.reduce ins(%{{.*}} : tensor<256xi8>) outs(%{{.*}} : tensor) dimensions = [0] +// CHECK-DAG: arith.addi %in, %init : i8 + +// CHECK-DAG: arith.constant 0 : i16 +// CHECK-DAG: tensor.reshape %{{.*}} : (tensor<8x8x4xi16>, {{.*}}) -> tensor<256xi16> +// CHECK-DAG: linalg.reduce ins(%{{.*}} : tensor<256xi16>) outs(%{{.*}} : tensor) dimensions = [0] +// CHECK-DAG: arith.addi %in, %init : i16 + +// CHECK-DAG: arith.constant 0 : i32 +// CHECK-DAG: tensor.reshape %{{.*}} : (tensor<8x8x4xi32>, {{.*}}) -> tensor<256xi32> +// CHECK-DAG: linalg.reduce ins(%{{.*}} : tensor<256xi32>) outs(%{{.*}} : tensor) dimensions = [0] +// CHECK-DAG: arith.addi %in, %init : i32 + +// CHECK-DAG: arith.constant 0 : i64 +// CHECK-DAG: tensor.reshape %{{.*}} : (tensor<8x8x4xi64>, {{.*}}) -> tensor<256xi64> +// CHECK-DAG: linalg.reduce ins(%{{.*}} : tensor<256xi64>) outs(%{{.*}} : tensor) dimensions = [0] +// CHECK-DAG: arith.addi %in, %init : i64 + +// CHECK-DAG: arith.constant 0.0{{.*}} : f8E4M3FN +// CHECK-DAG: tensor.reshape %{{.*}} : (tensor<8x8x4xf8E4M3FN>, {{.*}}) -> tensor<256xf8E4M3FN> +// CHECK-DAG: linalg.reduce ins(%{{.*}} : tensor<256xf8E4M3FN>) outs(%{{.*}} : tensor) dimensions = [0] +// CHECK-DAG: arith.addf %in, 
%init : f8E4M3FN + +// CHECK-DAG: arith.constant 0.0{{.*}} : f8E5M2 +// CHECK-DAG: tensor.reshape %{{.*}} : (tensor<8x8x4xf8E5M2>, {{.*}}) -> tensor<256xf8E5M2> +// CHECK-DAG: linalg.reduce ins(%{{.*}} : tensor<256xf8E5M2>) outs(%{{.*}} : tensor) dimensions = [0] // CHECK-DAG: arith.addf %in, %init : f8E5M2 diff --git a/third_party/ascend/test/Conversion/TritonOp/xor_sum.mlir b/third_party/ascend/test/Conversion/TritonOp/xor_sum.mlir index 36254e929..19481f264 100644 --- a/third_party/ascend/test/Conversion/TritonOp/xor_sum.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/xor_sum.mlir @@ -1,125 +1,125 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s - -// === xor_sum === -// === i8 u8 version === -module { - tt.func public @fn_npu_u8_xor_sum( - %arg0: !tt.ptr, - %arg1: !tt.ptr - ) { - %c0_i8 = arith.constant 0 : i8 - %input = tt.splat %c0_i8 : i8 -> tensor<64x32xi8> - - %reduced = "tt.reduce"(%input) <{axis = 1 : i32}> ({ - ^bb0(%a: i8, %b: i8): - %xor = arith.xori %a, %b : i8 - tt.reduce.return %xor : i8 - }) : (tensor<64x32xi8>) -> tensor<64xi8> - - %ptrs = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> - %offs = tt.make_range {start = 0 : i32, end = 64 : i32} : tensor<64xi32> - %addrs = tt.addptr %ptrs, %offs : tensor<64x!tt.ptr>, tensor<64xi32> - tt.store %addrs, %reduced : tensor<64x!tt.ptr> - - tt.return - } -} - -// CHECK: linalg.reduce ins(%{{.*}} : tensor<64x32xi8>) outs(%{{.*}} : tensor<64xi8>) dimensions = [1] -// CHECK-NEXT: (%{{.*}}: i8, %{{.*}}: i8) { -// CHECK-NEXT: %{{[0-9]+}} = arith.xori %{{.*}}, %{{.*}} : i8 -// CHECK-NEXT: linalg.yield %{{[0-9]+}} : i8 -// CHECK: bufferization.materialize_in_destination - - -// === i16 u16 version === -module { - tt.func public @fn_npu_u16_xor_sum( - %arg0: !tt.ptr, - %arg1: !tt.ptr - ) { - %c0_i16 = arith.constant 0 : i16 - %input = tt.splat %c0_i16 : i16 -> tensor<64x32xi16> - - %reduced = "tt.reduce"(%input) <{axis = 1 : i32}> ({ - ^bb0(%a: i16, %b: i16): - %xor = arith.xori %a, %b : i16 - tt.reduce.return %xor : i16 - }) : (tensor<64x32xi16>) -> tensor<64xi16> - - %ptrs = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> - %offs = tt.make_range {start = 0 : i32, end = 64 : i32} : tensor<64xi32> - %addrs = tt.addptr %ptrs, %offs : tensor<64x!tt.ptr>, tensor<64xi32> - tt.store %addrs, %reduced : tensor<64x!tt.ptr> - - tt.return - } -} - -// CHECK: linalg.reduce ins(%{{.*}} : tensor<64x32xi16>) outs(%{{.*}} : tensor<64xi16>) dimensions = [1] -// CHECK-NEXT: (%{{.*}}: i16, %{{.*}}: i16) { -// CHECK-NEXT: %{{[0-9]+}} = arith.xori %{{.*}}, %{{.*}} : i16 -// CHECK-NEXT: linalg.yield %{{[0-9]+}} : i16 -// CHECK: bufferization.materialize_in_destination - - -// === i32 u32 version === -module { - tt.func public @fn_npu_u32_xor_sum( - %arg0: !tt.ptr, - %arg1: !tt.ptr - ) { - %c0_i32 = arith.constant 0 : i32 - %input = tt.splat %c0_i32 : i32 -> tensor<64x32xi32> - - %reduced = "tt.reduce"(%input) <{axis = 1 : i32}> ({ - ^bb0(%a: i32, %b: i32): - %xor = arith.xori %a, %b : i32 - tt.reduce.return %xor : i32 - }) : (tensor<64x32xi32>) -> tensor<64xi32> - - %ptrs = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> - %offs = tt.make_range {start = 0 : i32, end = 64 : i32} : tensor<64xi32> - %addrs = tt.addptr %ptrs, %offs : tensor<64x!tt.ptr>, tensor<64xi32> - tt.store 
%addrs, %reduced : tensor<64x!tt.ptr> - - tt.return - } -} - -// CHECK: linalg.reduce ins(%{{.*}} : tensor<64x32xi32>) outs(%{{.*}} : tensor<64xi32>) dimensions = [1] -// CHECK-NEXT: (%{{.*}}: i32, %{{.*}}: i32) { -// CHECK-NEXT: %{{[0-9]+}} = arith.xori %{{.*}}, %{{.*}} : i32 -// CHECK-NEXT: linalg.yield %{{[0-9]+}} : i32 -// CHECK: bufferization.materialize_in_destination - - -// === i64 u64 version === -module { - tt.func public @fn_npu_u64_xor_sum( - %arg0: !tt.ptr, - %arg1: !tt.ptr - ) { - %c0_i64 = arith.constant 0 : i64 - %input = tt.splat %c0_i64 : i64 -> tensor<64x32xi64> - - %reduced = "tt.reduce"(%input) <{axis = 1 : i32}> ({ - ^bb0(%a: i64, %b: i64): - %xor = arith.xori %a, %b : i64 - tt.reduce.return %xor : i64 - }) : (tensor<64x32xi64>) -> tensor<64xi64> - - %ptrs = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> - %offs = tt.make_range {start = 0 : i32, end = 64 : i32} : tensor<64xi32> - %addrs = tt.addptr %ptrs, %offs : tensor<64x!tt.ptr>, tensor<64xi32> - tt.store %addrs, %reduced : tensor<64x!tt.ptr> - - tt.return - } -} - -// CHECK: linalg.reduce ins(%{{.*}} : tensor<64x32xi64>) outs(%{{.*}} : tensor<64xi64>) dimensions = [1] -// CHECK-NEXT: (%{{.*}}: i64, %{{.*}}: i64) { -// CHECK-NEXT: %{{[0-9]+}} = arith.xori %{{.*}}, %{{.*}} : i64 -// CHECK-NEXT: linalg.yield %{{[0-9]+}} : i64 +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s + +// === xor_sum === +// === i8 u8 version === +module { + tt.func public @fn_npu_u8_xor_sum( + %arg0: !tt.ptr, + %arg1: !tt.ptr + ) { + %c0_i8 = arith.constant 0 : i8 + %input = tt.splat %c0_i8 : i8 -> tensor<64x32xi8> + + %reduced = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%a: i8, %b: i8): + %xor = arith.xori %a, %b : i8 + tt.reduce.return %xor : i8 + }) : (tensor<64x32xi8>) -> tensor<64xi8> + + %ptrs = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %offs = tt.make_range {start = 0 : i32, end = 64 : i32} : tensor<64xi32> + %addrs = tt.addptr %ptrs, %offs : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %addrs, %reduced : tensor<64x!tt.ptr> + + tt.return + } +} + +// CHECK: linalg.reduce ins(%{{.*}} : tensor<64x32xi8>) outs(%{{.*}} : tensor<64xi8>) dimensions = [1] +// CHECK-NEXT: (%{{.*}}: i8, %{{.*}}: i8) { +// CHECK-NEXT: %{{[0-9]+}} = arith.xori %{{.*}}, %{{.*}} : i8 +// CHECK-NEXT: linalg.yield %{{[0-9]+}} : i8 +// CHECK: bufferization.materialize_in_destination + + +// === i16 u16 version === +module { + tt.func public @fn_npu_u16_xor_sum( + %arg0: !tt.ptr, + %arg1: !tt.ptr + ) { + %c0_i16 = arith.constant 0 : i16 + %input = tt.splat %c0_i16 : i16 -> tensor<64x32xi16> + + %reduced = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%a: i16, %b: i16): + %xor = arith.xori %a, %b : i16 + tt.reduce.return %xor : i16 + }) : (tensor<64x32xi16>) -> tensor<64xi16> + + %ptrs = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %offs = tt.make_range {start = 0 : i32, end = 64 : i32} : tensor<64xi32> + %addrs = tt.addptr %ptrs, %offs : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %addrs, %reduced : tensor<64x!tt.ptr> + + tt.return + } +} + +// CHECK: linalg.reduce ins(%{{.*}} : tensor<64x32xi16>) outs(%{{.*}} : tensor<64xi16>) dimensions = [1] +// CHECK-NEXT: (%{{.*}}: i16, %{{.*}}: i16) { +// CHECK-NEXT: %{{[0-9]+}} = arith.xori %{{.*}}, %{{.*}} : i16 +// CHECK-NEXT: 
linalg.yield %{{[0-9]+}} : i16 +// CHECK: bufferization.materialize_in_destination + + +// === i32 u32 version === +module { + tt.func public @fn_npu_u32_xor_sum( + %arg0: !tt.ptr, + %arg1: !tt.ptr + ) { + %c0_i32 = arith.constant 0 : i32 + %input = tt.splat %c0_i32 : i32 -> tensor<64x32xi32> + + %reduced = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%a: i32, %b: i32): + %xor = arith.xori %a, %b : i32 + tt.reduce.return %xor : i32 + }) : (tensor<64x32xi32>) -> tensor<64xi32> + + %ptrs = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %offs = tt.make_range {start = 0 : i32, end = 64 : i32} : tensor<64xi32> + %addrs = tt.addptr %ptrs, %offs : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %addrs, %reduced : tensor<64x!tt.ptr> + + tt.return + } +} + +// CHECK: linalg.reduce ins(%{{.*}} : tensor<64x32xi32>) outs(%{{.*}} : tensor<64xi32>) dimensions = [1] +// CHECK-NEXT: (%{{.*}}: i32, %{{.*}}: i32) { +// CHECK-NEXT: %{{[0-9]+}} = arith.xori %{{.*}}, %{{.*}} : i32 +// CHECK-NEXT: linalg.yield %{{[0-9]+}} : i32 +// CHECK: bufferization.materialize_in_destination + + +// === i64 u64 version === +module { + tt.func public @fn_npu_u64_xor_sum( + %arg0: !tt.ptr, + %arg1: !tt.ptr + ) { + %c0_i64 = arith.constant 0 : i64 + %input = tt.splat %c0_i64 : i64 -> tensor<64x32xi64> + + %reduced = "tt.reduce"(%input) <{axis = 1 : i32}> ({ + ^bb0(%a: i64, %b: i64): + %xor = arith.xori %a, %b : i64 + tt.reduce.return %xor : i64 + }) : (tensor<64x32xi64>) -> tensor<64xi64> + + %ptrs = tt.splat %arg1 : !tt.ptr -> tensor<64x!tt.ptr> + %offs = tt.make_range {start = 0 : i32, end = 64 : i32} : tensor<64xi32> + %addrs = tt.addptr %ptrs, %offs : tensor<64x!tt.ptr>, tensor<64xi32> + tt.store %addrs, %reduced : tensor<64x!tt.ptr> + + tt.return + } +} + +// CHECK: linalg.reduce ins(%{{.*}} : tensor<64x32xi64>) outs(%{{.*}} : tensor<64xi64>) dimensions = [1] +// CHECK-NEXT: (%{{.*}}: i64, %{{.*}}: i64) { +// CHECK-NEXT: %{{[0-9]+}} = arith.xori %{{.*}}, %{{.*}} : i64 +// CHECK-NEXT: linalg.yield %{{[0-9]+}} : i64 // CHECK: bufferization.materialize_in_destination diff --git a/third_party/ascend/test/Conversion/TritonOp/zeros.mlir b/third_party/ascend/test/Conversion/TritonOp/zeros.mlir index fa2c20686..8c790db9d 100644 --- a/third_party/ascend/test/Conversion/TritonOp/zeros.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/zeros.mlir @@ -1,200 +1,200 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s - -// === zeros === -// === i8 u8 version === -module { - tt.func public @fn_npu_u8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<0> : tensor<8x8x4xi8> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> 
tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// CHECK: arith.constant 0 : i8 -// CHECK: tensor.empty() : tensor<8x8x4xi8> -// CHECK: linalg.fill ins(%{{.*}} : i8) outs(%{{.*}} : tensor<8x8x4xi8>) -> tensor<8x8x4xi8> -// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] -// CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} - - -// === i16 u16 version === -module { - tt.func public @fn_npu_u16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<0> : tensor<8x8x4xi16> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// ===== CHECKS ===== -// CHECK-DAG: arith.constant 0 : i16 -// CHECK-DAG: tensor.empty() : tensor<8x8x4xi16> -// CHECK-DAG: linalg.fill ins(%{{.*}} : i16) outs(%{{.*}} : tensor<8x8x4xi16>) -> tensor<8x8x4xi16> - - -// === i32 u32 version === -module { - tt.func public @fn_npu_i32(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<0> : tensor<8x8x4xi32> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> 
tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: arith.constant 0 : i32 -// CHECK-DAG: tensor.empty() : tensor<8x8x4xi32> -// CHECK-DAG: linalg.fill ins(%{{.*}} : i32) outs(%{{.*}} : tensor<8x8x4xi32>) -> tensor<8x8x4xi32> - - -// === i64 u64 version === -module { - tt.func public @fn_npu_i64(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<0> : tensor<8x8x4xi64> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// CHECK-DAG: arith.constant 0 : i64 -// CHECK-DAG: tensor.empty() : tensor<8x8x4xi64> -// CHECK-DAG: linalg.fill ins(%{{.*}} : i64) outs(%{{.*}} : tensor<8x8x4xi64>) -> tensor<8x8x4xi64> - - -// === f8E4M3FN version === -module { - tt.func public @fn_npu_f8E4M3FN(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<0.000000e+00> : tensor<8x8x4xf8E4M3FN> - %cst_0 = arith.constant dense<4> : tensor<1x8x1xi32> - %cst_1 = arith.constant dense<4> : tensor<8x1x1xi32> - %cst_2 = arith.constant dense<8> : tensor<8x1x1xi32> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = arith.muli %3, %cst_2 : tensor<8x1x1xi32> - %5 = arith.muli %4, %cst_1 : tensor<8x1x1xi32> - %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %8 = arith.muli %7, %cst_0 : tensor<1x8x1xi32> - %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %11 = arith.addi %9, %10 : tensor<8x8x1xi32> - %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %16 = arith.addi %14, %15 : tensor<8x8x4xi32> - %17 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %18, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// ----- - 
-// CHECK-DAG: arith.constant 0.0 : f8E4M3FN -// CHECK-DAG: tensor.empty() : tensor<8x8x4xf8E4M3FN> -// CHECK-DAG: linalg.fill ins(%{{.*}} : f8E4M3FN) outs(%{{.*}} : tensor<8x8x4xf8E4M3FN>) -> tensor<8x8x4xf8E4M3FN> - - -// === f8E5M2 version === -module { - tt.func public @fn_npu_f8E5M2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<0.000000e+00> : tensor<8x8x4xf8E5M2> - %cst_0 = arith.constant dense<4> : tensor<1x8x1xi32> - %cst_1 = arith.constant dense<4> : tensor<8x1x1xi32> - %cst_2 = arith.constant dense<8> : tensor<8x1x1xi32> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = arith.muli %3, %cst_2 : tensor<8x1x1xi32> - %5 = arith.muli %4, %cst_1 : tensor<8x1x1xi32> - %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %8 = arith.muli %7, %cst_0 : tensor<1x8x1xi32> - %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %11 = arith.addi %9, %10 : tensor<8x8x1xi32> - %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %16 = arith.addi %14, %15 : tensor<8x8x4xi32> - %17 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %18, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// ----- - -// CHECK-DAG: arith.constant 0.0 : f8E5M2 -// CHECK-DAG: tensor.empty() : tensor<8x8x4xf8E5M2> -// CHECK-DAG: linalg.fill ins(%{{.*}} : f8E5M2) outs(%{{.*}} : tensor<8x8x4xf8E5M2>) -> tensor<8x8x4xf8E5M2> +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s + +// === zeros === +// === i8 u8 version === +module { + tt.func public @fn_npu_u8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0> : tensor<8x8x4xi8> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, 
%12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// CHECK: arith.constant 0 : i8 +// CHECK: tensor.empty() : tensor<8x8x4xi8> +// CHECK: linalg.fill ins(%{{.*}} : i8) outs(%{{.*}} : tensor<8x8x4xi8>) -> tensor<8x8x4xi8> +// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] +// CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} + + +// === i16 u16 version === +module { + tt.func public @fn_npu_u16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0> : tensor<8x8x4xi16> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// ===== CHECKS ===== +// CHECK-DAG: arith.constant 0 : i16 +// CHECK-DAG: tensor.empty() : tensor<8x8x4xi16> +// CHECK-DAG: linalg.fill ins(%{{.*}} : i16) outs(%{{.*}} : tensor<8x8x4xi16>) -> tensor<8x8x4xi16> + + +// === i32 u32 version === +module { + tt.func public @fn_npu_i32(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0> : tensor<8x8x4xi32> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: arith.constant 0 : i32 +// CHECK-DAG: tensor.empty() : tensor<8x8x4xi32> +// CHECK-DAG: linalg.fill ins(%{{.*}} : 
i32) outs(%{{.*}} : tensor<8x8x4xi32>) -> tensor<8x8x4xi32> + + +// === i64 u64 version === +module { + tt.func public @fn_npu_i64(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0> : tensor<8x8x4xi64> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// CHECK-DAG: arith.constant 0 : i64 +// CHECK-DAG: tensor.empty() : tensor<8x8x4xi64> +// CHECK-DAG: linalg.fill ins(%{{.*}} : i64) outs(%{{.*}} : tensor<8x8x4xi64>) -> tensor<8x8x4xi64> + + +// === f8E4M3FN version === +module { + tt.func public @fn_npu_f8E4M3FN(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<8x8x4xf8E4M3FN> + %cst_0 = arith.constant dense<4> : tensor<1x8x1xi32> + %cst_1 = arith.constant dense<4> : tensor<8x1x1xi32> + %cst_2 = arith.constant dense<8> : tensor<8x1x1xi32> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = arith.muli %3, %cst_2 : tensor<8x1x1xi32> + %5 = arith.muli %4, %cst_1 : tensor<8x1x1xi32> + %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %8 = arith.muli %7, %cst_0 : tensor<1x8x1xi32> + %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %11 = arith.addi %9, %10 : tensor<8x8x1xi32> + %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %16 = arith.addi %14, %15 : tensor<8x8x4xi32> + %17 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %18, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK-DAG: arith.constant 0.0 : f8E4M3FN +// CHECK-DAG: tensor.empty() : tensor<8x8x4xf8E4M3FN> +// CHECK-DAG: linalg.fill ins(%{{.*}} : f8E4M3FN) outs(%{{.*}} : tensor<8x8x4xf8E4M3FN>) -> tensor<8x8x4xf8E4M3FN> + + +// === f8E5M2 version === +module { + tt.func public 
@fn_npu_f8E5M2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<8x8x4xf8E5M2> + %cst_0 = arith.constant dense<4> : tensor<1x8x1xi32> + %cst_1 = arith.constant dense<4> : tensor<8x1x1xi32> + %cst_2 = arith.constant dense<8> : tensor<8x1x1xi32> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = arith.muli %3, %cst_2 : tensor<8x1x1xi32> + %5 = arith.muli %4, %cst_1 : tensor<8x1x1xi32> + %6 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %7 = tt.expand_dims %6 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %8 = arith.muli %7, %cst_0 : tensor<1x8x1xi32> + %9 = tt.broadcast %5 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %10 = tt.broadcast %8 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %11 = arith.addi %9, %10 : tensor<8x8x1xi32> + %12 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %14 = tt.broadcast %11 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %15 = tt.broadcast %13 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %16 = arith.addi %14, %15 : tensor<8x8x4xi32> + %17 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %18 = tt.addptr %17, %16 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %18, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK-DAG: arith.constant 0.0 : f8E5M2 +// CHECK-DAG: tensor.empty() : tensor<8x8x4xf8E5M2> +// CHECK-DAG: linalg.fill ins(%{{.*}} : f8E5M2) outs(%{{.*}} : tensor<8x8x4xf8E5M2>) -> tensor<8x8x4xf8E5M2> diff --git a/third_party/ascend/test/Conversion/TritonOp/zeros_like.mlir b/third_party/ascend/test/Conversion/TritonOp/zeros_like.mlir index f0d0fddc8..49eb25be0 100644 --- a/third_party/ascend/test/Conversion/TritonOp/zeros_like.mlir +++ b/third_party/ascend/test/Conversion/TritonOp/zeros_like.mlir @@ -1,135 +1,135 @@ -// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s - -// === zeroslike === -// === i8 u8 version === -module { - tt.func public @fn_npu_u8( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %cst = arith.constant dense<0> : tensor<8x8x4xi8> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> 
tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// === i16 u16 version === -module { - tt.func public @fn_npu_u16( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %cst = arith.constant dense<0> : tensor<8x8x4xi16> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// === i32 u32 version === -module { - tt.func public @fn_npu_i32( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %cst = arith.constant dense<0> : tensor<8x8x4xi32> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// === i64 u64 version === -module { - tt.func public @fn_npu_i64( - %arg0: !tt.ptr {tt.divisibility = 16 : i32}, - %arg1: !tt.ptr {tt.divisibility = 16 : i32} - ) { - %cst = arith.constant dense<0> : tensor<8x8x4xi64> - %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> - %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> - %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> - %3 = tt.expand_dims %2 {axis 
= 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> - %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> - %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> - %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> - %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> - %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> - %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> - %10 = arith.addi %8, %9 : tensor<8x8x1xi32> - %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> - %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> - %13 = arith.addi %11, %12 : tensor<8x8x4xi32> - %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> - %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> - tt.store %15, %cst : tensor<8x8x4x!tt.ptr> - tt.return - } -} - -// ===== CHECKS ===== -// CHECK-DAG: arith.constant 0 : i8 -// CHECK-DAG: tensor.empty() : tensor<8x8x4xi8> -// CHECK-DAG: linalg.fill ins(%{{.*}} : i8) outs(%{{.*}} : tensor<8x8x4xi8>) -> tensor<8x8x4xi8> - -// CHECK-DAG: arith.constant 0 : i16 -// CHECK-DAG: tensor.empty() : tensor<8x8x4xi16> -// CHECK-DAG: linalg.fill ins(%{{.*}} : i16) outs(%{{.*}} : tensor<8x8x4xi16>) -> tensor<8x8x4xi16> - -// CHECK-DAG: arith.constant 0 : i32 -// CHECK-DAG: tensor.empty() : tensor<8x8x4xi32> -// CHECK-DAG: linalg.fill ins(%{{.*}} : i32) outs(%{{.*}} : tensor<8x8x4xi32>) -> tensor<8x8x4xi32> - -// CHECK-DAG: arith.constant 0 : i64 -// CHECK-DAG: tensor.empty() : tensor<8x8x4xi64> -// CHECK-DAG: linalg.fill ins(%{{.*}} : i64) outs(%{{.*}} : tensor<8x8x4xi64>) -> tensor<8x8x4xi64> - -// Shared materialization pattern (same for all) -// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] +// RUN: triton-adapter-opt --triton-linearize --discrete-mask-access-conversion --triton-to-annotation --triton-to-unstructure --triton-to-hivm --triton-to-hfusion --triton-to-llvm --bubble-up-operation '--triton-to-linalg=global-kernel=false named-ops=True enable-nd2nz-on-vector=False' --split-input-file %s | FileCheck %s + +// === zeroslike === +// === i8 u8 version === +module { + tt.func public @fn_npu_u8( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %cst = arith.constant dense<0> : tensor<8x8x4xi8> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// 
=== i16 u16 version === +module { + tt.func public @fn_npu_u16( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %cst = arith.constant dense<0> : tensor<8x8x4xi16> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// === i32 u32 version === +module { + tt.func public @fn_npu_i32( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %cst = arith.constant dense<0> : tensor<8x8x4xi32> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// === i64 u64 version === +module { + tt.func public @fn_npu_i64( + %arg0: !tt.ptr {tt.divisibility = 16 : i32}, + %arg1: !tt.ptr {tt.divisibility = 16 : i32} + ) { + %cst = arith.constant dense<0> : tensor<8x8x4xi64> + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32> + %1 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32> + %3 = tt.expand_dims %2 {axis = 2 : i32} : tensor<8x1xi32> -> tensor<8x1x1xi32> + %4 = tt.expand_dims %0 {axis = 0 : i32} : tensor<8xi32> -> tensor<1x8xi32> + %5 = tt.expand_dims %4 {axis = 2 : i32} : tensor<1x8xi32> -> tensor<1x8x1xi32> + %6 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32> + %7 = tt.expand_dims %6 {axis = 1 : i32} : 
tensor<1x4xi32> -> tensor<1x1x4xi32> + %8 = tt.broadcast %3 : tensor<8x1x1xi32> -> tensor<8x8x1xi32> + %9 = tt.broadcast %5 : tensor<1x8x1xi32> -> tensor<8x8x1xi32> + %10 = arith.addi %8, %9 : tensor<8x8x1xi32> + %11 = tt.broadcast %10 : tensor<8x8x1xi32> -> tensor<8x8x4xi32> + %12 = tt.broadcast %7 : tensor<1x1x4xi32> -> tensor<8x8x4xi32> + %13 = arith.addi %11, %12 : tensor<8x8x4xi32> + %14 = tt.splat %arg0 : !tt.ptr -> tensor<8x8x4x!tt.ptr> + %15 = tt.addptr %14, %13 : tensor<8x8x4x!tt.ptr>, tensor<8x8x4xi32> + tt.store %15, %cst : tensor<8x8x4x!tt.ptr> + tt.return + } +} + +// ===== CHECKS ===== +// CHECK-DAG: arith.constant 0 : i8 +// CHECK-DAG: tensor.empty() : tensor<8x8x4xi8> +// CHECK-DAG: linalg.fill ins(%{{.*}} : i8) outs(%{{.*}} : tensor<8x8x4xi8>) -> tensor<8x8x4xi8> + +// CHECK-DAG: arith.constant 0 : i16 +// CHECK-DAG: tensor.empty() : tensor<8x8x4xi16> +// CHECK-DAG: linalg.fill ins(%{{.*}} : i16) outs(%{{.*}} : tensor<8x8x4xi16>) -> tensor<8x8x4xi16> + +// CHECK-DAG: arith.constant 0 : i32 +// CHECK-DAG: tensor.empty() : tensor<8x8x4xi32> +// CHECK-DAG: linalg.fill ins(%{{.*}} : i32) outs(%{{.*}} : tensor<8x8x4xi32>) -> tensor<8x8x4xi32> + +// CHECK-DAG: arith.constant 0 : i64 +// CHECK-DAG: tensor.empty() : tensor<8x8x4xi64> +// CHECK-DAG: linalg.fill ins(%{{.*}} : i64) outs(%{{.*}} : tensor<8x8x4xi64>) -> tensor<8x8x4xi64> + +// Shared materialization pattern (same for all) +// CHECK: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [8, 8, 4] // CHECK: bufferization.materialize_in_destination %{{.*}} in writable %{{.*}} diff --git a/third_party/ascend/test/sglang/v0.4.8/test_state_get_lse.py b/third_party/ascend/test/sglang/v0.4.8/test_state_get_lse.py index bde07b01b..3a833563f 100644 --- a/third_party/ascend/test/sglang/v0.4.8/test_state_get_lse.py +++ b/third_party/ascend/test/sglang/v0.4.8/test_state_get_lse.py @@ -1,50 +1,50 @@ -import sys -import pytest -import triton -import torch -import triton.language as tl - -sys.path.append("..") -import test_common - - -# source: sgl-kernel/tests/test_merge_state.py -@triton.jit -def state_get_lse(o, m, d): - return m + tl.log2(d) - - -@triton.jit -def _test_state_get_lse_kernel( - m_ptr, - d_ptr, - out_ptr, - n_elements: tl.constexpr, -): - pid = tl.program_id(0) - mask = pid < n_elements - m = tl.load(m_ptr + pid, mask=mask) - d = tl.load(d_ptr + pid, mask=mask) - lse = state_get_lse(None, m, d) - tl.store(out_ptr + pid, lse, mask=mask) - - -def test_context_fwd_kernel(ptfile_path): - try: - data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) - except Exception as e: - pytest.fail(f"load file {ptfile_path} failed: {str(e)}") - - # ptfile format: - # [input_data] (dict): - # [gpu_output] (dict): - # [grid] : - input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') - - _test_state_get_lse_kernel[data["grid"]](**input_data) - - # compare the results of GPU and NPU. 
- try: - test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') - except ValueError as e: - pytest.fail(f"The testcase failed") +import sys +import pytest +import triton +import torch +import triton.language as tl + +sys.path.append("..") +import test_common + + +# source: sgl-kernel/tests/test_merge_state.py +@triton.jit +def state_get_lse(o, m, d): + return m + tl.log2(d) + + +@triton.jit +def _test_state_get_lse_kernel( + m_ptr, + d_ptr, + out_ptr, + n_elements: tl.constexpr, +): + pid = tl.program_id(0) + mask = pid < n_elements + m = tl.load(m_ptr + pid, mask=mask) + d = tl.load(d_ptr + pid, mask=mask) + lse = state_get_lse(None, m, d) + tl.store(out_ptr + pid, lse, mask=mask) + + +def test_context_fwd_kernel(ptfile_path): + try: + data = torch.load(ptfile_path, map_location=torch.device('cpu'), weights_only=False) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {str(e)}") + + # ptfile format: + # [input_data] (dict): + # [gpu_output] (dict): + # [grid] : + input_data = test_common.convert_tensor_with_device_type(data["input_data"], device_type='npu') + + _test_state_get_lse_kernel[data["grid"]](**input_data) + + # compare the results of GPU and NPU. + try: + test_common.compare_data_precision(data["gpu_output"], input_data, device_type='cpu') + except ValueError as e: + pytest.fail(f"The testcase failed") diff --git a/third_party/ascend/triton_ascend.cc b/third_party/ascend/triton_ascend.cc new file mode 100644 index 000000000..ef70fca4a --- /dev/null +++ b/third_party/ascend/triton_ascend.cc @@ -0,0 +1,385 @@ +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" + +#include "incubated/Conversion/DiscreteMaskAccessConversion/Passes.h" +#include "incubated/Conversion/TritonToAnnotation/Passes.h" +#include "incubated/Conversion/TritonToLinalgIncubated/Passes.h" +#include "incubated/Conversion/TritonToStructuredIncubated/Passes.h" +#include "incubated/Conversion/TritonToUnstructureIncubated/Passes.h" +#include "npu/Conversion/TritonToHFusion/Passes.h" +#include "npu/Conversion/TritonToHIVM/Passes.h" +#include "npu/Conversion/TritonToLLVM/Passes.h" +#include "npu/Dialect/TritonAscend/IR/TritonAscendDialect.h" + +#include "ir.h" // TritonOpBuilder +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include + +namespace py = pybind11; +using namespace ir; +using namespace mlir; + +void init_triton_ascend_ir(py::module &&m) { + auto *builder_cls = ir::getBuilderClass(); + builder_cls + ->def("create_extract_scalar", + [](TritonOpBuilder &self, Value &src, + std::vector &indices) -> Value { + llvm::SmallVector arg_indices; + for (const auto &i : indices) { + auto iTy = i.getType(); + if (!iTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), i); + arg_indices.push_back(v); + } else { + arg_indices.push_back(i); + } + } + auto ret = self.create(src, arg_indices); + return ret; + }) + .def("create_extract_slice", + [](TritonOpBuilder &self, Value &ful, std::vector &offs_vec, + std::vector &sizs_vec, std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if 
(!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get( + retSizes, + cast(ful.getType()).getElementType()); + + return self.create(retTy, ful, offsets, + sizes, strides); + }) + .def("create_insert_slice", + [](TritonOpBuilder &self, Value &ful, Value &sub, + std::vector &offs_vec, std::vector &sizs_vec, + std::vector &strd_vec) -> Value { + llvm::SmallVector offsets; + for (const auto &o : offs_vec) { + auto oTy = o.getType(); + if (!oTy.isIndex()) { + auto v = self.create( + self.getBuilder().getIndexType(), o); + offsets.push_back(v); + } else { + offsets.push_back(o); + } + } + llvm::SmallVector sizes; + llvm::SmallVector retSizes; + for (const auto &s : sizs_vec) { + auto v = self.create(s); + sizes.push_back(v); + retSizes.push_back(s); + } + llvm::SmallVector strides; + for (const auto &s : strd_vec) { + auto v = self.create(s); + strides.push_back(v); + } + auto retTy = RankedTensorType::get( + retSizes, + cast(ful.getType()).getElementType()); + auto ret = self.create(sub, ful, offsets, + sizes, strides); + return ret; + }) + .def("create_custom_op_for_inter_core_sync", + [](TritonOpBuilder &self, std::string &op_name, + std::string &mode_or_sender, int id) -> void { + auto args = self.getBuilder().getArrayAttr( + {self.getBuilder().getStringAttr(mode_or_sender), + self.getBuilder().getI32IntegerAttr(id)}); + self.create(op_name, args, ValueRange()); + }) + .def("create_index_select_simd", + [](TritonOpBuilder &self, Value &src, Value &index, int32_t dim, + std::vector &srcShape, std::vector &srcOffset, + std::vector &readShape, + std::vector &returnShape) -> Value { + auto &builder = self.getBuilder(); + auto loc = self.getLastLoc(); + + // Get element type from source pointer + Type elemType; + if (auto ptrTy = dyn_cast(src.getType())) { + elemType = ptrTy.getPointeeType(); + } else { + llvm::report_fatal_error( + "index_select_simd: src must be pointer type"); + } + + // Create return tensor type + llvm::SmallVector retShape; + for (const auto &s : returnShape) { + retShape.push_back(s); + } + auto retTensorType = RankedTensorType::get(retShape, elemType); + + // Convert srcShape and srcOffset values to index type if needed + llvm::SmallVector srcShapeIndex; + for (auto val : srcShape) { + if (!val.getType().isIndex()) { + val = self.create(builder.getIndexType(), + val); + } + srcShapeIndex.push_back(val); + } + + llvm::SmallVector srcOffsetIndex; + for (auto val : srcOffset) { + if (!val.getType().isIndex()) { + val = self.create(builder.getIndexType(), + val); + } + srcOffsetIndex.push_back(val); + } + + // Create attributes + auto dimAttr = builder.getI32IntegerAttr(dim); + auto readShapeAttr = builder.getDenseI32ArrayAttr(readShape); + + // Create the IndexSelectSimdOp + // Parameter order must match TritonOps.td definition: + // src, index, dim, src_shape, src_offset, read_shape + auto indexSelectSimdOp = + builder.create( + loc, + retTensorType, // result type + src, // src pointer + index, // index tensor + dimAttr, // dim attribute + srcShapeIndex, // src_shape (variadic, index type) + srcOffsetIndex, // src_offset (variadic, index type) + readShapeAttr 
// read_shape attribute + ); + + return indexSelectSimdOp.getResult(); + }) + .def("create_embedding_gather", + [](TritonOpBuilder &self, Value &src, Value &idx, + const int64_t bound, const int64_t blksiz, + std::vector &offsets, + std::vector &numels) -> Value { + auto elemTy = cast(src.getType()).getPointeeType(); + auto idxTy = cast(idx.getType()); + auto idxShape = idxTy.getShape(); + std::vector retShape(idxShape.begin(), idxShape.end()); + retShape.push_back(blksiz); + auto resType = RankedTensorType::get(retShape, elemTy); + auto idxBitWidth = idxTy.getElementType().getIntOrFloatBitWidth(); + auto bound_val = + self.create(bound, idxBitWidth); + auto blksiz_val = + self.create(blksiz, idxBitWidth); + + return self.create( + resType, src, idx, bound_val, blksiz_val, offsets, numels); + }) + .def("create_index_put", + [](TritonOpBuilder &self, Value &ptr, Value &index, Value &value, + const int32_t dim, const int64_t indexBoundary, + std::vector &endOffset, std::vector &startOffset, + std::vector &dstStride) -> void { + // dim need to be i32 type + auto dimI32Ty = self.getBuilder().getI32Type(); + auto dim_val = self.create(dim, dimI32Ty); + // indexBoundary need to be i64 type + auto BoundI64Ty = self.getBuilder().getI64Type(); + auto bound_val = + self.create(indexBoundary, BoundI64Ty); + + self.create(ptr, index, value, dim_val, + bound_val, endOffset, + startOffset, dstStride); + }) + .def("create_gather_out_to_ub", + [](TritonOpBuilder &self, Value &src, Value &index, + const int64_t indexBoundary, const int32_t dim, + std::vector &srcStride, std::vector &endOffset, + std::vector &startOffset, + std::optional &other) -> Value { + auto elemTy = cast(src.getType()).getPointeeType(); + auto idxTy = cast(index.getType()); + auto idxShape = idxTy.getShape(); + std::vector retShape(idxShape.begin(), idxShape.end()); + auto resType = RankedTensorType::get(retShape, elemTy); + + // indexBoundary need to be i64 type + auto BoundI64Ty = self.getBuilder().getI64Type(); + auto bound_val = + self.create(indexBoundary, BoundI64Ty); + // dim need to be i32 type + auto dimI32Ty = self.getBuilder().getI32Type(); + auto dim_val = self.create(dim, dimI32Ty); + return self.create( + resType, src, index, bound_val, dim_val, srcStride, endOffset, + startOffset, other.value_or(Value())); + }) + .def("create_scatter_ub_to_out", + [](TritonOpBuilder &self, Value &ptr, Value &value, Value &index, + const int64_t indexBoundary, const int32_t dim, + std::vector &dstStride, std::vector &endOffset, + std::vector &startOffset) -> void { + auto idxTy = cast(index.getType()); + + // indexBoundary need to be i64 type + auto BoundI64Ty = self.getBuilder().getI64Type(); + auto bound_val = + self.create(indexBoundary, BoundI64Ty); + // dim need to be i32 type + auto dimI32Ty = self.getBuilder().getI32Type(); + auto dim_val = self.create(dim, dimI32Ty); + + self.create( + ptr, value, index, bound_val, dim_val, dstStride, endOffset, + startOffset); + }) + // Add sort + .def("create_sort", + [](TritonOpBuilder &self, Value src, int64_t dim, + bool descending) -> Value { + auto &builder = self.getBuilder(); + auto loc = self.getLastLoc(); + + auto dimAttr = builder.getI64IntegerAttr(dim); + auto descendingAttr = builder.getBoolAttr(descending); + + auto op = builder.create(loc, src, dimAttr, + descendingAttr); + + return op->getResult(0); + }) + // Add flip + .def("create_flip", + [](TritonOpBuilder &self, Value src, int64_t dim) -> Value { + auto &builder = self.getBuilder(); + auto loc = self.getLastLoc(); + + auto 
dimAttr = builder.getI64IntegerAttr(dim); + + auto op = + builder.create(loc, src, dimAttr); + + return op->getResult(0); + }) + .def("create_tanh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + // Add an annotation + .def("create_annotation", + [](TritonOpBuilder &self, Value &ptr, const std::string &attrKey, + Attribute &attrVal) { + auto annotationOp = self.create(ptr); + annotationOp->setAttr(self.getBuilder().getStringAttr(attrKey), + attrVal); + }); +} + +void init_triton_ascend_passes_ttir(py::module &&m) { + m.def("add_triton_to_structure_incubated", [](mlir::PassManager &pm, + bool enableMaskFallbackConversion, + bool optimizeDynamicOffset, + bool compileOn91095) { + pm.addPass(mlir::triton::createTritonToStructuredIncubatedPass( + enableMaskFallbackConversion, optimizeDynamicOffset, compileOn91095)); + }); + + m.def("add_triton_to_annotation", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createTritonToAnnotationPass()); + }); + + m.def("add_triton_to_linalg_incubated", + [](mlir::PassManager &pm, bool globalKernel, bool namedOps, + bool enableNd2nzOnVector, bool enableSelectAnalysis, + bool compileOn91095) { + pm.addPass(mlir::triton::Incubated::createTritonToLinalgIncubatedPass( + globalKernel, namedOps, enableNd2nzOnVector, enableSelectAnalysis, + compileOn91095)); + }); + + m.def("add_triton_to_unstructure_incubated", + [](mlir::PassManager &pm, bool compileOn91095, bool forceSimtTemplate) { + TritonToUnstructureIncubatedOptions opts; + opts.compileOn91095 = compileOn91095; + opts.forceSimtTemplate = forceSimtTemplate; + pm.addPass(mlir::triton::createTritonToUnstructureIncubatedPass(opts)); + }); + + m.def("add_triton_to_hfusion", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createTritonToHFusionPass()); + }); + + m.def("add_discrete_mask_access_conversion", + [](mlir::PassManager &pm, bool compileOn91095, bool forceSimtTemplate) { + DiscreteMaskAccessConversionOptions opts; + opts.compileOn91095 = compileOn91095; + opts.forceSimtTemplate = forceSimtTemplate; + pm.addPass( + mlir::triton::createDiscreteMaskAccessConversionPass(opts)); + }); + + m.def("add_triton_to_hivm", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createTritonToHIVMPass()); + }); + + m.def("add_triton_to_llvm", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createTritonToLLVMPass()); + }); + + m.def("add_bubble_up_operation", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createBubbleUpOperationPass()); + }); +} + +// Forward declaration for ascend_ir bindings (defined in ascend_ir.cc) +void init_ascend_ir(py::module &&m); + +void init_triton_ascend(py::module &&m) { + auto passes = m.def_submodule("passes"); + // load dialects + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); + + init_triton_ascend_passes_ttir(passes.def_submodule("ttir")); + init_triton_ascend_ir(m.def_submodule("ascend_ir")); + + // Initialize ascend IR bindings (ascendnpu_ir_builder, scope/hivm dialects) + init_ascend_ir(m.def_submodule("ir")); +} diff --git a/third_party/ascend/triton_ascend.cpp b/third_party/ascend/triton_ascend.cpp deleted file mode 100644 index 4077908ab..000000000 --- a/third_party/ascend/triton_ascend.cpp +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
- */ -#include "incubated/Conversion/DiscreteMaskAccessConversion/DiscreteMaskAccessConversionPass.h" -#include "incubated/Conversion/TritonLinearize/TritonLinearize.h" -#include "incubated/Conversion/TritonToAnnotation/TritonToAnnotation.h" -#include "incubated/Conversion/TritonToLinalgIncubated/TritonToLinalgIncubatedPass.h" -#include "incubated/Conversion/TritonToUnstructureIncubated/Passes.h" -#include "mlir/Pass/PassManager.h" -#include "npu/Conversion/TritonToHFusion/Passes.h" -#include "npu/Conversion/TritonToHIVM/Passes.h" -#include "npu/Conversion/TritonToLLVM/TritonToLLVM.h" -#include "passes.h" -#include "triton-shared/Conversion/TritonToLinalgExperimental/TritonToLinalgExperimental.h" -#include "npu/Dialect/TritonAscend/IR/TritonAscendDialect.h" - -#define PY_SSIZE_T_CLEAN -#include -namespace py = pybind11; - -void init_triton_ascend_passes_convert(py::module &&m) { - - ADD_PASS_WRAPPER_0("add_triton_to_linalg_pipeline", - mlir::triton::createTritonToLinalgExperimentalPass); - ADD_PASS_WRAPPER_0("add_triton_linearize", - mlir::triton::createTritonLinearizePass); - ADD_PASS_WRAPPER_0("add_triton_to_annotation", - mlir::triton::createTritonToAnnotationPass); - ADD_PASS_WRAPPER_0("add_triton_to_hivm", - mlir::triton::createTritonToHIVMPass); - ADD_PASS_WRAPPER_0("add_triton_to_hfusion", - mlir::triton::createTritonToHFusionPass); - ADD_PASS_WRAPPER_0("add_triton_to_llvm", - mlir::triton::createTritonToLLVMPass); - m.def( - "add_triton_discretemaskaccessconversion", - [](mlir::PassManager &pm, bool compile_on_910_95, - bool force_simt_template) { - DiscreteMaskAccessConversionOptions options; - options.compileOn91095 = compile_on_910_95; - options.forceSimtTemplate = force_simt_template; - pm.addPass( - mlir::triton::createDiscreteMaskAccessConversionPass(options)); - }, - py::arg("pm"), py::arg("compile_on_910_95"), - py::arg("force_simt_template")); - m.def("add_triton_to_unstructure_incubated", [](mlir::PassManager &pm, - bool compileOn91095, bool forceSimtTemplate) { - TritonToUnstructureIncubatedOptions opts; - opts.compileOn91095 = compileOn91095; - opts.forceSimtTemplate = forceSimtTemplate; - pm.addPass(mlir::triton::createTritonToUnstructureIncubatedPass(opts));}); - m.def("add_bubble_up_operation", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createBubbleUpOperationPass());}); - m.def( - "add_triton_to_linalg_incubated", - [](mlir::PassManager &pm, bool global_kernel, bool named_ops, - bool enable_nd2nz_on_vector, bool enable_select_analysis, - bool compile_on_910_95) { - pm.addPass(mlir::triton::Incubated::createTritonToLinalgIncubatedPass( - global_kernel, named_ops, enable_nd2nz_on_vector, - enable_select_analysis, compile_on_910_95)); - }, - py::arg("pm"), py::arg("global_kernel"), py::arg("named_ops"), - py::arg("enable_nd2nz_on_vector"), py::arg("enable_select_analysis"), - py::arg("compile_on_910_95")); -} - -// register ascend passes to triton -void init_triton_ascend(py::module &&m) { - auto passes = m.def_submodule("passes"); - init_triton_ascend_passes_convert(passes.def_submodule("convert")); -} diff --git a/third_party/ascend/examples/tutorials/01-vector-add.py b/third_party/ascend/tutorials/01-vector-add.py similarity index 98% rename from third_party/ascend/examples/tutorials/01-vector-add.py rename to third_party/ascend/tutorials/01-vector-add.py index 0520a3d08..e59935f45 100644 --- a/third_party/ascend/examples/tutorials/01-vector-add.py +++ b/third_party/ascend/tutorials/01-vector-add.py @@ -1,4 +1,6 @@ # Copyright (c) Huawei Technologies 
Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/third_party/ascend/examples/tutorials/02-fused-softmax.py b/third_party/ascend/tutorials/02-fused-softmax.py similarity index 100% rename from third_party/ascend/examples/tutorials/02-fused-softmax.py rename to third_party/ascend/tutorials/02-fused-softmax.py diff --git a/third_party/ascend/examples/tutorials/03-layer-norm.py b/third_party/ascend/tutorials/03-layer-norm.py similarity index 100% rename from third_party/ascend/examples/tutorials/03-layer-norm.py rename to third_party/ascend/tutorials/03-layer-norm.py diff --git a/third_party/ascend/examples/tutorials/04-fused-attention.py b/third_party/ascend/tutorials/04-fused-attention.py similarity index 100% rename from third_party/ascend/examples/tutorials/04-fused-attention.py rename to third_party/ascend/tutorials/04-fused-attention.py diff --git a/third_party/ascend/examples/tutorials/05-matrix-multiplication.py b/third_party/ascend/tutorials/05-matrix-multiplication.py similarity index 100% rename from third_party/ascend/examples/tutorials/05-matrix-multiplication.py rename to third_party/ascend/tutorials/05-matrix-multiplication.py diff --git a/third_party/ascend/examples/tutorials/06-demo-autotune.py b/third_party/ascend/tutorials/06-demo-autotune.py similarity index 100% rename from third_party/ascend/examples/tutorials/06-demo-autotune.py rename to third_party/ascend/tutorials/06-demo-autotune.py diff --git a/third_party/ascend/examples/tutorials/07-profiler.py b/third_party/ascend/tutorials/07-profiler.py similarity index 100% rename from third_party/ascend/examples/tutorials/07-profiler.py rename to third_party/ascend/tutorials/07-profiler.py diff --git a/third_party/ascend/examples/tutorials/08-demo-libentry.py b/third_party/ascend/tutorials/08-demo-libentry.py similarity index 100% rename from third_party/ascend/examples/tutorials/08-demo-libentry.py rename to third_party/ascend/tutorials/08-demo-libentry.py diff --git a/third_party/ascend/examples/tutorials/09-gather.py b/third_party/ascend/tutorials/09-gather.py similarity index 100% rename from third_party/ascend/examples/tutorials/09-gather.py rename to third_party/ascend/tutorials/09-gather.py diff --git a/third_party/ascend/examples/tutorials/10-gather_sorted.py b/third_party/ascend/tutorials/10-gather_sorted.py similarity index 100% rename from third_party/ascend/examples/tutorials/10-gather_sorted.py rename to third_party/ascend/tutorials/10-gather_sorted.py diff --git a/third_party/ascend/examples/tutorials/11-rab_time.py b/third_party/ascend/tutorials/11-rab_time.py similarity index 100% rename from third_party/ascend/examples/tutorials/11-rab_time.py rename to third_party/ascend/tutorials/11-rab_time.py diff --git a/third_party/ascend/examples/tutorials/12-hstu_attention.py b/third_party/ascend/tutorials/12-hstu_attention.py similarity index 100% rename from third_party/ascend/examples/tutorials/12-hstu_attention.py rename to third_party/ascend/tutorials/12-hstu_attention.py diff --git a/third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized.py b/third_party/ascend/tutorials/13-matrix-multiplication-optimized.py similarity index 100% rename from third_party/ascend/examples/tutorials/13-matrix-multiplication-optimized.py rename to 
third_party/ascend/tutorials/13-matrix-multiplication-optimized.py diff --git a/third_party/ascend/examples/tutorials/14-accuracy-comparison.py b/third_party/ascend/tutorials/14-accuracy-comparison.py similarity index 100% rename from third_party/ascend/examples/tutorials/14-accuracy-comparison.py rename to third_party/ascend/tutorials/14-accuracy-comparison.py diff --git a/third_party/ascend/examples/tutorials/15-embedding_gather_demo.py b/third_party/ascend/tutorials/15-embedding_gather_demo.py similarity index 100% rename from third_party/ascend/examples/tutorials/15-embedding_gather_demo.py rename to third_party/ascend/tutorials/15-embedding_gather_demo.py diff --git a/third_party/ascend/unittest/autotune_ut/test_common.py b/third_party/ascend/unittest/autotune_ut/test_common.py new file mode 100644 index 000000000..d512d3358 --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/test_common.py @@ -0,0 +1,86 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import unittest.mock as mock +import pytest + + +def MockAutoTilingTunerRun(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + + # generate key + all_args = {**self.nargs, **kwargs} + try: + self._autoparse_axis_params(all_args) + except ValueError as e: + if "Missing required arguments" in str(e): + pass + else: + raise + return { + "keys": self.keys, + "split_params": self.split_params, + "tiling_params": self.tiling_params, + "low_dim_axes": self.low_dim_axes, + "reduction_axes": self.reduction_axes, + } + + +def check_axes_parse_res(act: dict, ref: dict): + """ + Compare two axes parse results that may use different symbolic axis names, + but map to the same semantic dimensions via the 'keys' field. 
+ """ + ref_keys = ref["keys"] + act_keys = act["keys"] + + assert set(ref_keys.values()) == set(act_keys.values()), \ + f"Semantic dimensions mismatch: ref={set(ref_keys.values())}, act={set(act_keys.values())}" + + def normalize_param_dict(param_dict: dict, sym_to_sem: dict) -> dict: + """Convert {symbol: value} -> {semantic: value}""" + return {sym_to_sem[sym]: value for sym, value in param_dict.items()} + + ref_split = normalize_param_dict(ref["split_params"], ref_keys) + act_split = normalize_param_dict(act["split_params"], act_keys) + + ref_tiling = normalize_param_dict(ref["tiling_params"], ref_keys) + act_tiling = normalize_param_dict(act["tiling_params"], act_keys) + + def normalize_axis_list(axis_list: list, sym_to_sem: dict) -> list: + return sorted(sym_to_sem[sym] for sym in axis_list) + + ref_low = normalize_axis_list(ref["low_dim_axes"], ref_keys) + act_low = normalize_axis_list(act["low_dim_axes"], act_keys) + + ref_red = normalize_axis_list(ref["reduction_axes"], ref_keys) + act_red = normalize_axis_list(act["reduction_axes"], act_keys) + + # Compare normalized structures + assert ref_split == act_split, f"split_params mismatch: {ref_split} vs {act_split}" + assert ref_tiling == act_tiling, f"tiling_params mismatch: {ref_tiling} vs {act_tiling}" + assert ref_low == act_low, f"low_dim_axes mismatch: {ref_low} vs {act_low}" + assert ref_red == act_red, f"reduction_axes mismatch: {ref_red} vs {act_red}" + + +@pytest.fixture +def mock_autotuner(): + with mock.patch("triton.backends.ascend.runtime.autotuner.AutoTilingTuner.run", new=MockAutoTilingTunerRun): + yield diff --git a/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py b/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py new file mode 100644 index 000000000..2893bf347 --- /dev/null +++ b/third_party/ascend/unittest/autotune_ut/test_reduction_axes_parse.py @@ -0,0 +1,65 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2026. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import triton.language as tl +from test_common import check_axes_parse_res, mock_autotuner + + +def test_triton_max_last_dim_case(mock_autotuner): + import triton.backends.ascend.runtime + + @triton.autotune(configs=[], key=["x0_numel", "r1_numel"]) + @triton.jit + def triton_max_last_dim( + in_ptr0, + out_ptr0, + x0_numel, + r1_numel, + X0BLOCK: tl.constexpr, + X0BLOCK_SUB: tl.constexpr, + R1BLOCK_SUB: tl.constexpr, + ): + x0_offset = tl.program_id(0) * X0BLOCK + base_x0 = tl.arange(0, X0BLOCK_SUB) + loops_x0 = (X0BLOCK + X0BLOCK_SUB - 1) // X0BLOCK_SUB + base_r1 = tl.arange(0, R1BLOCK_SUB) + loops_r1 = (r1_numel + R1BLOCK_SUB - 1) // R1BLOCK_SUB + for loop_x0 in range(loops_x0): + x0 = x0_offset + (loop_x0 * X0BLOCK_SUB) + base_x0[:, None] + x0_mask = x0 < min(X0BLOCK + x0_offset, x0_numel) + block_val = tl.full([X0BLOCK_SUB, R1BLOCK_SUB], float("-inf"), tl.float32) + for loop_r1 in range(loops_r1): + r1 = (loop_r1 * R1BLOCK_SUB) + base_r1[None, :] + r1_mask = r1 < r1_numel + tmp = tl.load(in_ptr0 + (r1 + r1_numel * x0), r1_mask & x0_mask, other=float("-inf")) + block_val = tl.maximum(block_val, tmp) + block_res = tl.max(block_val, axis=1)[:, None] + tl.store(out_ptr0 + x0, block_res, x0_mask) + + ref_res = { + "keys": {"x": "x0_numel", "ry": "r1_numel"}, + "split_params": {"x": "X0BLOCK"}, + "tiling_params": {"x": "X0BLOCK_SUB", "ry": "R1BLOCK_SUB"}, + "low_dim_axes": ["ry"], + "reduction_axes": ["ry"], + } + act_res = triton_max_last_dim[(1, )]() + + check_axes_parse_res(act_res, ref_res) diff --git a/third_party/ascend/examples/generalization_cases/conftest.py b/third_party/ascend/unittest/conftest.py similarity index 88% rename from third_party/ascend/examples/generalization_cases/conftest.py rename to third_party/ascend/unittest/conftest.py index 633db8b1f..a284dcdec 100644 --- a/third_party/ascend/examples/generalization_cases/conftest.py +++ b/third_party/ascend/unittest/conftest.py @@ -31,3 +31,8 @@ def assign_npu(worker_id): idx = int(worker_id.replace("gw", "")) npu_id = idx % npu_count torch.npu.set_device(npu_id) + + +def pytest_addoption(parser): + parser.addoption("--kernel", action="append", default=None, + help="run only specified kernel(s); can be supplied multiple times") diff --git a/third_party/ascend/examples/generalization_cases/acc_util.py b/third_party/ascend/unittest/generalization_cases/acc_util.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/acc_util.py rename to third_party/ascend/unittest/generalization_cases/acc_util.py diff --git a/third_party/ascend/examples/conftest.py b/third_party/ascend/unittest/generalization_cases/conftest.py similarity index 100% rename from third_party/ascend/examples/conftest.py rename to third_party/ascend/unittest/generalization_cases/conftest.py diff --git a/third_party/ascend/examples/generalization_cases/test_abs.py b/third_party/ascend/unittest/generalization_cases/test_abs.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_abs.py rename to third_party/ascend/unittest/generalization_cases/test_abs.py diff --git a/third_party/ascend/examples/generalization_cases/test_advance.py b/third_party/ascend/unittest/generalization_cases/test_advance.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_advance.py rename to third_party/ascend/unittest/generalization_cases/test_advance.py diff --git a/third_party/ascend/examples/generalization_cases/test_and.py 
b/third_party/ascend/unittest/generalization_cases/test_and.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_and.py rename to third_party/ascend/unittest/generalization_cases/test_and.py diff --git a/third_party/ascend/examples/generalization_cases/test_argmax.py b/third_party/ascend/unittest/generalization_cases/test_argmax.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_argmax.py rename to third_party/ascend/unittest/generalization_cases/test_argmax.py diff --git a/third_party/ascend/examples/generalization_cases/test_argmin.py b/third_party/ascend/unittest/generalization_cases/test_argmin.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_argmin.py rename to third_party/ascend/unittest/generalization_cases/test_argmin.py diff --git a/third_party/ascend/examples/generalization_cases/test_associative_scan.py b/third_party/ascend/unittest/generalization_cases/test_associative_scan.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_associative_scan.py rename to third_party/ascend/unittest/generalization_cases/test_associative_scan.py diff --git a/third_party/ascend/examples/generalization_cases/test_atan.py b/third_party/ascend/unittest/generalization_cases/test_atan.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_atan.py rename to third_party/ascend/unittest/generalization_cases/test_atan.py diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_add.py b/third_party/ascend/unittest/generalization_cases/test_atomic_add.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_atomic_add.py rename to third_party/ascend/unittest/generalization_cases/test_atomic_add.py diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_and.py b/third_party/ascend/unittest/generalization_cases/test_atomic_and.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_atomic_and.py rename to third_party/ascend/unittest/generalization_cases/test_atomic_and.py diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_cas.py b/third_party/ascend/unittest/generalization_cases/test_atomic_cas.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_atomic_cas.py rename to third_party/ascend/unittest/generalization_cases/test_atomic_cas.py diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_max.py b/third_party/ascend/unittest/generalization_cases/test_atomic_max.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_atomic_max.py rename to third_party/ascend/unittest/generalization_cases/test_atomic_max.py diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_min.py b/third_party/ascend/unittest/generalization_cases/test_atomic_min.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_atomic_min.py rename to third_party/ascend/unittest/generalization_cases/test_atomic_min.py diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_or.py b/third_party/ascend/unittest/generalization_cases/test_atomic_or.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_atomic_or.py rename to third_party/ascend/unittest/generalization_cases/test_atomic_or.py diff --git 
a/third_party/ascend/examples/generalization_cases/test_atomic_xchg.py b/third_party/ascend/unittest/generalization_cases/test_atomic_xchg.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_atomic_xchg.py rename to third_party/ascend/unittest/generalization_cases/test_atomic_xchg.py diff --git a/third_party/ascend/examples/generalization_cases/test_atomic_xor.py b/third_party/ascend/unittest/generalization_cases/test_atomic_xor.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_atomic_xor.py rename to third_party/ascend/unittest/generalization_cases/test_atomic_xor.py diff --git a/third_party/ascend/examples/generalization_cases/test_broadcast.py b/third_party/ascend/unittest/generalization_cases/test_broadcast.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_broadcast.py rename to third_party/ascend/unittest/generalization_cases/test_broadcast.py diff --git a/third_party/ascend/examples/generalization_cases/test_broadcast_to.py b/third_party/ascend/unittest/generalization_cases/test_broadcast_to.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_broadcast_to.py rename to third_party/ascend/unittest/generalization_cases/test_broadcast_to.py diff --git a/third_party/ascend/examples/generalization_cases/test_cast.py b/third_party/ascend/unittest/generalization_cases/test_cast.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_cast.py rename to third_party/ascend/unittest/generalization_cases/test_cast.py diff --git a/third_party/ascend/examples/generalization_cases/test_cdiv.py b/third_party/ascend/unittest/generalization_cases/test_cdiv.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_cdiv.py rename to third_party/ascend/unittest/generalization_cases/test_cdiv.py diff --git a/third_party/ascend/examples/generalization_cases/test_ceil.py b/third_party/ascend/unittest/generalization_cases/test_ceil.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_ceil.py rename to third_party/ascend/unittest/generalization_cases/test_ceil.py diff --git a/third_party/ascend/examples/generalization_cases/test_common.py b/third_party/ascend/unittest/generalization_cases/test_common.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_common.py rename to third_party/ascend/unittest/generalization_cases/test_common.py diff --git a/third_party/ascend/examples/generalization_cases/test_cos.py b/third_party/ascend/unittest/generalization_cases/test_cos.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_cos.py rename to third_party/ascend/unittest/generalization_cases/test_cos.py diff --git a/third_party/ascend/examples/generalization_cases/test_count_dim0.py b/third_party/ascend/unittest/generalization_cases/test_count_dim0.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_count_dim0.py rename to third_party/ascend/unittest/generalization_cases/test_count_dim0.py diff --git a/third_party/ascend/examples/generalization_cases/test_count_dim1.py b/third_party/ascend/unittest/generalization_cases/test_count_dim1.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_count_dim1.py rename to third_party/ascend/unittest/generalization_cases/test_count_dim1.py diff --git 
a/third_party/ascend/examples/generalization_cases/test_cumprod.py b/third_party/ascend/unittest/generalization_cases/test_cumprod.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_cumprod.py rename to third_party/ascend/unittest/generalization_cases/test_cumprod.py diff --git a/third_party/ascend/examples/generalization_cases/test_cumsum.py b/third_party/ascend/unittest/generalization_cases/test_cumsum.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_cumsum.py rename to third_party/ascend/unittest/generalization_cases/test_cumsum.py diff --git a/third_party/ascend/examples/generalization_cases/test_debug_barrier.py b/third_party/ascend/unittest/generalization_cases/test_debug_barrier.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_debug_barrier.py rename to third_party/ascend/unittest/generalization_cases/test_debug_barrier.py diff --git a/third_party/ascend/examples/generalization_cases/test_device_print.py b/third_party/ascend/unittest/generalization_cases/test_device_print.py similarity index 96% rename from third_party/ascend/examples/generalization_cases/test_device_print.py rename to third_party/ascend/unittest/generalization_cases/test_device_print.py index 1b8e86df0..5421db96e 100644 --- a/third_party/ascend/examples/generalization_cases/test_device_print.py +++ b/third_party/ascend/unittest/generalization_cases/test_device_print.py @@ -1,106 +1,106 @@ -import torch -import torch_npu -import triton -import triton.language as tl -import pytest -import sys -import os -import subprocess -import tempfile -import textwrap - -os.environ["TRITON_DEVICE_PRINT"] = "1" -os.environ["TRITON_ENABLE_TASKQUEUE"] = "0" - -shape = (8, ) -XS = 8 -XVALS_INT = [ - 0, - torch.iinfo(torch.int8).min, - torch.iinfo(torch.int8).max, - torch.iinfo(torch.int16).min, - torch.iinfo(torch.int16).max, - torch.iinfo(torch.int32).min, - torch.iinfo(torch.int32).max, - torch.iinfo(torch.int32).max + 1 -] - - -@pytest.mark.parametrize('sigtype', ['int32', 'int64', 'int16', 'int8', 'float32', 'float16', 'bfloat16']) -def test_device_print_int32(sigtype): - - with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: - temp_script = f.name - - f.write( - textwrap.dedent(f""" -import torch -import torch_npu -import triton -import triton.language as tl -import os -import sys - -os.environ["TRITON_DEVICE_PRINT"] = "1" -os.environ["TRITON_ENABLE_TASKQUEUE"] = "0" - -@triton.jit -def triton_kernel(out_ptr0, in_ptr0, in_ptr1, XBLOCK: tl.constexpr): - idx = tl.arange(0, XBLOCK) - tmp0 = tl.load(in_ptr0 + idx) - tmp1 = tl.load(in_ptr1 + idx) - tmp2 = tmp0 + tmp1 - tl.device_print("OUTPUT = ", tmp2) - tl.store(out_ptr0 + idx, tmp2) - -def main(): - shape = (8,) - XS = 8 - dtype = torch.{sigtype} - - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - - XVALS_INT = [0, -128, 127, -32768, 32767, -2147483648, 2147483647, 2147483648] - for i in range(x1.numel()): - x1[i] = XVALS_INT[i] - - out = torch.empty_like(x0) - - triton_kernel[1,](out, x0, x1, XS) - - print("Kernel execution completed") - - return out - -if __name__ == "__main__": - result = main() - print(f"Result shape: {{result.shape}}") - """)) - dtype = eval(f"torch.{sigtype}") - x0 = torch.zeros(shape, dtype=dtype).npu() - x1 = torch.ones(shape, dtype=dtype).npu() - for i in range(x1.numel()): - x1[i] = XVALS_INT[i] - - torch_ref = x0 + x1 - if 'int' in sigtype: - torch_ref_str = 
','.join([str(int(val)) for val in torch_ref.cpu().numpy()]) - else: - values = torch_ref.cpu() - if values.dtype == torch.bfloat16: - values = values.float() - torch_ref_str = ','.join([f"{float(val):.6f}" for val in values.numpy()]) - - result = subprocess.run([sys.executable, temp_script], capture_output=True, text=True, env=os.environ.copy()) - - captured_output = result.stdout + "\n=== STDERR ===\n" + result.stderr - - ##with open(f"manual_capture_{sigtype}.txt", "w") as f: - ##f.write(captured_output) - ##f.write(f"torch_ref:{torch_ref_str}") - - if os.path.exists(temp_script): - os.remove(temp_script) - - assert torch_ref_str in captured_output +import torch +import torch_npu +import triton +import triton.language as tl +import pytest +import sys +import os +import subprocess +import tempfile +import textwrap + +os.environ["TRITON_DEVICE_PRINT"] = "1" +os.environ["TRITON_ENABLE_TASKQUEUE"] = "0" + +shape = (8, ) +XS = 8 +XVALS_INT = [ + 0, + torch.iinfo(torch.int8).min, + torch.iinfo(torch.int8).max, + torch.iinfo(torch.int16).min, + torch.iinfo(torch.int16).max, + torch.iinfo(torch.int32).min, + torch.iinfo(torch.int32).max, + torch.iinfo(torch.int32).max + 1 +] + + +@pytest.mark.parametrize('sigtype', ['int32', 'int64', 'int16', 'int8', 'float32', 'float16', 'bfloat16']) +def test_device_print_int32(sigtype): + + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + temp_script = f.name + + f.write( + textwrap.dedent(f""" +import torch +import torch_npu +import triton +import triton.language as tl +import os +import sys + +os.environ["TRITON_DEVICE_PRINT"] = "1" +os.environ["TRITON_ENABLE_TASKQUEUE"] = "0" + +@triton.jit +def triton_kernel(out_ptr0, in_ptr0, in_ptr1, XBLOCK: tl.constexpr): + idx = tl.arange(0, XBLOCK) + tmp0 = tl.load(in_ptr0 + idx) + tmp1 = tl.load(in_ptr1 + idx) + tmp2 = tmp0 + tmp1 + tl.device_print("OUTPUT = ", tmp2) + tl.store(out_ptr0 + idx, tmp2) + +def main(): + shape = (8,) + XS = 8 + dtype = torch.{sigtype} + + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + + XVALS_INT = [0, -128, 127, -32768, 32767, -2147483648, 2147483647, 2147483648] + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + + out = torch.empty_like(x0) + + triton_kernel[1,](out, x0, x1, XS) + + print("Kernel execution completed") + + return out + +if __name__ == "__main__": + result = main() + print(f"Result shape: {{result.shape}}") + """)) + dtype = eval(f"torch.{sigtype}") + x0 = torch.zeros(shape, dtype=dtype).npu() + x1 = torch.ones(shape, dtype=dtype).npu() + for i in range(x1.numel()): + x1[i] = XVALS_INT[i] + + torch_ref = x0 + x1 + if 'int' in sigtype: + torch_ref_str = ','.join([str(int(val)) for val in torch_ref.cpu().numpy()]) + else: + values = torch_ref.cpu() + if values.dtype == torch.bfloat16: + values = values.float() + torch_ref_str = ','.join([f"{float(val):.6f}" for val in values.numpy()]) + + result = subprocess.run([sys.executable, temp_script], capture_output=True, text=True, env=os.environ.copy()) + + captured_output = result.stdout + "\n=== STDERR ===\n" + result.stderr + + ##with open(f"manual_capture_{sigtype}.txt", "w") as f: + ##f.write(captured_output) + ##f.write(f"torch_ref:{torch_ref_str}") + + if os.path.exists(temp_script): + os.remove(temp_script) + + assert torch_ref_str in captured_output diff --git a/third_party/ascend/examples/generalization_cases/test_div_rn.py b/third_party/ascend/unittest/generalization_cases/test_div_rn.py similarity index 100% rename from 
third_party/ascend/examples/generalization_cases/test_div_rn.py rename to third_party/ascend/unittest/generalization_cases/test_div_rn.py diff --git a/third_party/ascend/examples/generalization_cases/test_dot_scaled.py b/third_party/ascend/unittest/generalization_cases/test_dot_scaled.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_dot_scaled.py rename to third_party/ascend/unittest/generalization_cases/test_dot_scaled.py diff --git a/third_party/ascend/examples/generalization_cases/test_eq.py b/third_party/ascend/unittest/generalization_cases/test_eq.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_eq.py rename to third_party/ascend/unittest/generalization_cases/test_eq.py diff --git a/third_party/ascend/examples/generalization_cases/test_erf.py b/third_party/ascend/unittest/generalization_cases/test_erf.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_erf.py rename to third_party/ascend/unittest/generalization_cases/test_erf.py diff --git a/third_party/ascend/examples/generalization_cases/test_exp.py b/third_party/ascend/unittest/generalization_cases/test_exp.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_exp.py rename to third_party/ascend/unittest/generalization_cases/test_exp.py diff --git a/third_party/ascend/examples/generalization_cases/test_exp2.py b/third_party/ascend/unittest/generalization_cases/test_exp2.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_exp2.py rename to third_party/ascend/unittest/generalization_cases/test_exp2.py diff --git a/third_party/ascend/examples/generalization_cases/test_expand_dims.py b/third_party/ascend/unittest/generalization_cases/test_expand_dims.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_expand_dims.py rename to third_party/ascend/unittest/generalization_cases/test_expand_dims.py diff --git a/third_party/ascend/examples/generalization_cases/test_fdiv.py b/third_party/ascend/unittest/generalization_cases/test_fdiv.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_fdiv.py rename to third_party/ascend/unittest/generalization_cases/test_fdiv.py diff --git a/third_party/ascend/examples/generalization_cases/test_flip.py b/third_party/ascend/unittest/generalization_cases/test_flip.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_flip.py rename to third_party/ascend/unittest/generalization_cases/test_flip.py diff --git a/third_party/ascend/examples/generalization_cases/test_full_op.py b/third_party/ascend/unittest/generalization_cases/test_full_op.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_full_op.py rename to third_party/ascend/unittest/generalization_cases/test_full_op.py diff --git a/third_party/ascend/examples/generalization_cases/test_ge_op.py b/third_party/ascend/unittest/generalization_cases/test_ge_op.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_ge_op.py rename to third_party/ascend/unittest/generalization_cases/test_ge_op.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_add.py b/third_party/ascend/unittest/generalization_cases/test_general_add.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_add.py rename to 
third_party/ascend/unittest/generalization_cases/test_general_add.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_arange.py b/third_party/ascend/unittest/generalization_cases/test_general_arange.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_arange.py rename to third_party/ascend/unittest/generalization_cases/test_general_arange.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_cat.py b/third_party/ascend/unittest/generalization_cases/test_general_cat.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_cat.py rename to third_party/ascend/unittest/generalization_cases/test_general_cat.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_clamp.py b/third_party/ascend/unittest/generalization_cases/test_general_clamp.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_clamp.py rename to third_party/ascend/unittest/generalization_cases/test_general_clamp.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_div.py b/third_party/ascend/unittest/generalization_cases/test_general_div.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_div.py rename to third_party/ascend/unittest/generalization_cases/test_general_div.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_floor.py b/third_party/ascend/unittest/generalization_cases/test_general_floor.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_floor.py rename to third_party/ascend/unittest/generalization_cases/test_general_floor.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_floordiv.py b/third_party/ascend/unittest/generalization_cases/test_general_floordiv.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_floordiv.py rename to third_party/ascend/unittest/generalization_cases/test_general_floordiv.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_fma.py b/third_party/ascend/unittest/generalization_cases/test_general_fma.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_fma.py rename to third_party/ascend/unittest/generalization_cases/test_general_fma.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_gather.py b/third_party/ascend/unittest/generalization_cases/test_general_gather.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_gather.py rename to third_party/ascend/unittest/generalization_cases/test_general_gather.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_interleave.py b/third_party/ascend/unittest/generalization_cases/test_general_interleave.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_interleave.py rename to third_party/ascend/unittest/generalization_cases/test_general_interleave.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_join.py b/third_party/ascend/unittest/generalization_cases/test_general_join.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_join.py rename to third_party/ascend/unittest/generalization_cases/test_general_join.py diff --git 
a/third_party/ascend/examples/generalization_cases/test_general_log.py b/third_party/ascend/unittest/generalization_cases/test_general_log.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_log.py rename to third_party/ascend/unittest/generalization_cases/test_general_log.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_log2.py b/third_party/ascend/unittest/generalization_cases/test_general_log2.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_log2.py rename to third_party/ascend/unittest/generalization_cases/test_general_log2.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_maximum.py b/third_party/ascend/unittest/generalization_cases/test_general_maximum.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_maximum.py rename to third_party/ascend/unittest/generalization_cases/test_general_maximum.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_minimum.py b/third_party/ascend/unittest/generalization_cases/test_general_minimum.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_minimum.py rename to third_party/ascend/unittest/generalization_cases/test_general_minimum.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_mul.py b/third_party/ascend/unittest/generalization_cases/test_general_mul.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_mul.py rename to third_party/ascend/unittest/generalization_cases/test_general_mul.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_ravel.py b/third_party/ascend/unittest/generalization_cases/test_general_ravel.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_ravel.py rename to third_party/ascend/unittest/generalization_cases/test_general_ravel.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_reshape.py b/third_party/ascend/unittest/generalization_cases/test_general_reshape.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_reshape.py rename to third_party/ascend/unittest/generalization_cases/test_general_reshape.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_rsqrt.py b/third_party/ascend/unittest/generalization_cases/test_general_rsqrt.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_rsqrt.py rename to third_party/ascend/unittest/generalization_cases/test_general_rsqrt.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_sigmoid.py b/third_party/ascend/unittest/generalization_cases/test_general_sigmoid.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_sigmoid.py rename to third_party/ascend/unittest/generalization_cases/test_general_sigmoid.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_sin.py b/third_party/ascend/unittest/generalization_cases/test_general_sin.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_sin.py rename to third_party/ascend/unittest/generalization_cases/test_general_sin.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_softmax.py 
b/third_party/ascend/unittest/generalization_cases/test_general_softmax.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_softmax.py rename to third_party/ascend/unittest/generalization_cases/test_general_softmax.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_split.py b/third_party/ascend/unittest/generalization_cases/test_general_split.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_split.py rename to third_party/ascend/unittest/generalization_cases/test_general_split.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_sub.py b/third_party/ascend/unittest/generalization_cases/test_general_sub.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_sub.py rename to third_party/ascend/unittest/generalization_cases/test_general_sub.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_tensor_descriptor.py b/third_party/ascend/unittest/generalization_cases/test_general_tensor_descriptor.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_tensor_descriptor.py rename to third_party/ascend/unittest/generalization_cases/test_general_tensor_descriptor.py diff --git a/third_party/ascend/examples/generalization_cases/test_general_view.py b/third_party/ascend/unittest/generalization_cases/test_general_view.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_general_view.py rename to third_party/ascend/unittest/generalization_cases/test_general_view.py diff --git a/third_party/ascend/examples/generalization_cases/test_gt_op.py b/third_party/ascend/unittest/generalization_cases/test_gt_op.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_gt_op.py rename to third_party/ascend/unittest/generalization_cases/test_gt_op.py diff --git a/third_party/ascend/examples/generalization_cases/test_invalid_fp8.py b/third_party/ascend/unittest/generalization_cases/test_invalid_fp8.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_invalid_fp8.py rename to third_party/ascend/unittest/generalization_cases/test_invalid_fp8.py diff --git a/third_party/ascend/examples/generalization_cases/test_invert.py b/third_party/ascend/unittest/generalization_cases/test_invert.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_invert.py rename to third_party/ascend/unittest/generalization_cases/test_invert.py diff --git a/third_party/ascend/examples/generalization_cases/test_le_op.py b/third_party/ascend/unittest/generalization_cases/test_le_op.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_le_op.py rename to third_party/ascend/unittest/generalization_cases/test_le_op.py diff --git a/third_party/ascend/examples/generalization_cases/test_load_store.py b/third_party/ascend/unittest/generalization_cases/test_load_store.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_load_store.py rename to third_party/ascend/unittest/generalization_cases/test_load_store.py diff --git a/third_party/ascend/examples/generalization_cases/test_log1p.py b/third_party/ascend/unittest/generalization_cases/test_log1p.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_log1p.py rename to 
third_party/ascend/unittest/generalization_cases/test_log1p.py diff --git a/third_party/ascend/examples/generalization_cases/test_logical_and_op.py b/third_party/ascend/unittest/generalization_cases/test_logical_and_op.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_logical_and_op.py rename to third_party/ascend/unittest/generalization_cases/test_logical_and_op.py diff --git a/third_party/ascend/examples/generalization_cases/test_logical_or_op.py b/third_party/ascend/unittest/generalization_cases/test_logical_or_op.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_logical_or_op.py rename to third_party/ascend/unittest/generalization_cases/test_logical_or_op.py diff --git a/third_party/ascend/examples/generalization_cases/test_lshift_op.py b/third_party/ascend/unittest/generalization_cases/test_lshift_op.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_lshift_op.py rename to third_party/ascend/unittest/generalization_cases/test_lshift_op.py diff --git a/third_party/ascend/examples/generalization_cases/test_lt_op.py b/third_party/ascend/unittest/generalization_cases/test_lt_op.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_lt_op.py rename to third_party/ascend/unittest/generalization_cases/test_lt_op.py diff --git a/third_party/ascend/examples/generalization_cases/test_make_blkptr_matmul.py b/third_party/ascend/unittest/generalization_cases/test_make_blkptr_matmul.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_make_blkptr_matmul.py rename to third_party/ascend/unittest/generalization_cases/test_make_blkptr_matmul.py diff --git a/third_party/ascend/examples/generalization_cases/test_make_block_ptr.py b/third_party/ascend/unittest/generalization_cases/test_make_block_ptr.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_make_block_ptr.py rename to third_party/ascend/unittest/generalization_cases/test_make_block_ptr.py diff --git a/third_party/ascend/examples/generalization_cases/test_matmul.py b/third_party/ascend/unittest/generalization_cases/test_matmul.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_matmul.py rename to third_party/ascend/unittest/generalization_cases/test_matmul.py diff --git a/third_party/ascend/examples/generalization_cases/test_max.py b/third_party/ascend/unittest/generalization_cases/test_max.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_max.py rename to third_party/ascend/unittest/generalization_cases/test_max.py diff --git a/third_party/ascend/examples/generalization_cases/test_min.py b/third_party/ascend/unittest/generalization_cases/test_min.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_min.py rename to third_party/ascend/unittest/generalization_cases/test_min.py diff --git a/third_party/ascend/examples/generalization_cases/test_mod.py b/third_party/ascend/unittest/generalization_cases/test_mod.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_mod.py rename to third_party/ascend/unittest/generalization_cases/test_mod.py diff --git a/third_party/ascend/examples/generalization_cases/test_ne.py b/third_party/ascend/unittest/generalization_cases/test_ne.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_ne.py 
rename to third_party/ascend/unittest/generalization_cases/test_ne.py diff --git a/third_party/ascend/examples/generalization_cases/test_neg.py b/third_party/ascend/unittest/generalization_cases/test_neg.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_neg.py rename to third_party/ascend/unittest/generalization_cases/test_neg.py diff --git a/third_party/ascend/examples/generalization_cases/test_not.py b/third_party/ascend/unittest/generalization_cases/test_not.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_not.py rename to third_party/ascend/unittest/generalization_cases/test_not.py diff --git a/third_party/ascend/examples/generalization_cases/test_or.py b/third_party/ascend/unittest/generalization_cases/test_or.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_or.py rename to third_party/ascend/unittest/generalization_cases/test_or.py diff --git a/third_party/ascend/examples/generalization_cases/test_permute_1d_2d.py b/third_party/ascend/unittest/generalization_cases/test_permute_1d_2d.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_permute_1d_2d.py rename to third_party/ascend/unittest/generalization_cases/test_permute_1d_2d.py diff --git a/third_party/ascend/examples/generalization_cases/test_permute_3d.py b/third_party/ascend/unittest/generalization_cases/test_permute_3d.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_permute_3d.py rename to third_party/ascend/unittest/generalization_cases/test_permute_3d.py diff --git a/third_party/ascend/examples/generalization_cases/test_permute_4d_5d.py b/third_party/ascend/unittest/generalization_cases/test_permute_4d_5d.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_permute_4d_5d.py rename to third_party/ascend/unittest/generalization_cases/test_permute_4d_5d.py diff --git a/third_party/ascend/examples/generalization_cases/test_rand.py b/third_party/ascend/unittest/generalization_cases/test_rand.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_rand.py rename to third_party/ascend/unittest/generalization_cases/test_rand.py diff --git a/third_party/ascend/examples/generalization_cases/test_range.py b/third_party/ascend/unittest/generalization_cases/test_range.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_range.py rename to third_party/ascend/unittest/generalization_cases/test_range.py diff --git a/third_party/ascend/examples/generalization_cases/test_reduce.py b/third_party/ascend/unittest/generalization_cases/test_reduce.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_reduce.py rename to third_party/ascend/unittest/generalization_cases/test_reduce.py diff --git a/third_party/ascend/examples/generalization_cases/test_relu.py b/third_party/ascend/unittest/generalization_cases/test_relu.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_relu.py rename to third_party/ascend/unittest/generalization_cases/test_relu.py diff --git a/third_party/ascend/examples/generalization_cases/test_rshift_op.py b/third_party/ascend/unittest/generalization_cases/test_rshift_op.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_rshift_op.py rename to 
third_party/ascend/unittest/generalization_cases/test_rshift_op.py diff --git a/third_party/ascend/examples/generalization_cases/test_scalar_tensor.py b/third_party/ascend/unittest/generalization_cases/test_scalar_tensor.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_scalar_tensor.py rename to third_party/ascend/unittest/generalization_cases/test_scalar_tensor.py diff --git a/third_party/ascend/examples/generalization_cases/test_sort.py b/third_party/ascend/unittest/generalization_cases/test_sort.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_sort.py rename to third_party/ascend/unittest/generalization_cases/test_sort.py diff --git a/third_party/ascend/examples/generalization_cases/test_sqrt.py b/third_party/ascend/unittest/generalization_cases/test_sqrt.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_sqrt.py rename to third_party/ascend/unittest/generalization_cases/test_sqrt.py diff --git a/third_party/ascend/examples/generalization_cases/test_sqrt_rn.py b/third_party/ascend/unittest/generalization_cases/test_sqrt_rn.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_sqrt_rn.py rename to third_party/ascend/unittest/generalization_cases/test_sqrt_rn.py diff --git a/third_party/ascend/examples/generalization_cases/test_static_print_and_assert_op.py b/third_party/ascend/unittest/generalization_cases/test_static_print_and_assert_op.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_static_print_and_assert_op.py rename to third_party/ascend/unittest/generalization_cases/test_static_print_and_assert_op.py diff --git a/third_party/ascend/examples/generalization_cases/test_sum.py b/third_party/ascend/unittest/generalization_cases/test_sum.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_sum.py rename to third_party/ascend/unittest/generalization_cases/test_sum.py diff --git a/third_party/ascend/examples/generalization_cases/test_sum_dim0.py b/third_party/ascend/unittest/generalization_cases/test_sum_dim0.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_sum_dim0.py rename to third_party/ascend/unittest/generalization_cases/test_sum_dim0.py diff --git a/third_party/ascend/examples/generalization_cases/test_sum_dim1.py b/third_party/ascend/unittest/generalization_cases/test_sum_dim1.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_sum_dim1.py rename to third_party/ascend/unittest/generalization_cases/test_sum_dim1.py diff --git a/third_party/ascend/examples/generalization_cases/test_swizzle2d.py b/third_party/ascend/unittest/generalization_cases/test_swizzle2d.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_swizzle2d.py rename to third_party/ascend/unittest/generalization_cases/test_swizzle2d.py diff --git a/third_party/ascend/examples/generalization_cases/test_tan.py b/third_party/ascend/unittest/generalization_cases/test_tan.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_tan.py rename to third_party/ascend/unittest/generalization_cases/test_tan.py diff --git a/third_party/ascend/examples/generalization_cases/test_trans_1d_2d.py b/third_party/ascend/unittest/generalization_cases/test_trans_1d_2d.py similarity index 100% rename from 
third_party/ascend/examples/generalization_cases/test_trans_1d_2d.py rename to third_party/ascend/unittest/generalization_cases/test_trans_1d_2d.py diff --git a/third_party/ascend/examples/generalization_cases/test_trans_3d.py b/third_party/ascend/unittest/generalization_cases/test_trans_3d.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_trans_3d.py rename to third_party/ascend/unittest/generalization_cases/test_trans_3d.py diff --git a/third_party/ascend/examples/generalization_cases/test_trans_4d_5d.py b/third_party/ascend/unittest/generalization_cases/test_trans_4d_5d.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_trans_4d_5d.py rename to third_party/ascend/unittest/generalization_cases/test_trans_4d_5d.py diff --git a/third_party/ascend/examples/generalization_cases/test_umulhi.py b/third_party/ascend/unittest/generalization_cases/test_umulhi.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_umulhi.py rename to third_party/ascend/unittest/generalization_cases/test_umulhi.py diff --git a/third_party/ascend/examples/generalization_cases/test_where.py b/third_party/ascend/unittest/generalization_cases/test_where.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_where.py rename to third_party/ascend/unittest/generalization_cases/test_where.py diff --git a/third_party/ascend/examples/generalization_cases/test_xor.py b/third_party/ascend/unittest/generalization_cases/test_xor.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_xor.py rename to third_party/ascend/unittest/generalization_cases/test_xor.py diff --git a/third_party/ascend/examples/generalization_cases/test_xorsum.py b/third_party/ascend/unittest/generalization_cases/test_xorsum.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_xorsum.py rename to third_party/ascend/unittest/generalization_cases/test_xorsum.py diff --git a/third_party/ascend/examples/generalization_cases/test_zeros_op.py b/third_party/ascend/unittest/generalization_cases/test_zeros_op.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_zeros_op.py rename to third_party/ascend/unittest/generalization_cases/test_zeros_op.py diff --git a/third_party/ascend/examples/generalization_cases/test_zeroslike.py b/third_party/ascend/unittest/generalization_cases/test_zeroslike.py similarity index 100% rename from third_party/ascend/examples/generalization_cases/test_zeroslike.py rename to third_party/ascend/unittest/generalization_cases/test_zeroslike.py diff --git a/third_party/ascend/unittest/kernels/README.md b/third_party/ascend/unittest/kernels/README.md new file mode 100644 index 000000000..20eb7e42a --- /dev/null +++ b/third_party/ascend/unittest/kernels/README.md @@ -0,0 +1,62 @@ +# Guide: adding a new kernel test case +Adding a kernel test case involves three major steps: +1. Prepare the pt file +2. Add the kernel operator to the triton-ascend repository and run the kernel test locally +3. Upload the pt file to the OBS bucket + +## 1. Prepare the pt file + +The pt file stores the inputs and outputs captured on the GPU (or a reference implementation) as golden data; subsequent tests run the Triton kernel on the NPU and compare its results against that data. + +**Three-step generation flow** + +- **Step 1 - Build the GPU inputs and save a copy preprocessed into NPU kernel inputs**: build `input_data` from the parameters of the GPU kernel or PyTorch operator (key names must match the kernel parameters) and clone every Tensor to the CPU to form `input_data_before`; if the GPU operator's inputs differ from the NPU operator's, preprocess `input_data_before` in advance so that it matches the NPU operator's expected arguments. +- **Step 2 - Run the GPU kernel and collect the outputs**: run the GPU kernel on the GPU to obtain `gpu_output`, and move its Tensors to the CPU. +- **Step 3 - Package and save**: wrap `input_data_before`, `grid`, and `gpu_output` into a dict and save it with `torch.save` as
`{kernel_name}.pt`. If there are multiple cases, save them as a list of dicts (`[case0, case1]`). + +**Minimal example** + +```python +import copy +import torch + +DEVICE = torch.device("cuda:0") +batch_size = 2 +grid = (batch_size,) + +input_data = { + "output_token_ids_ptr": torch.zeros((batch_size, 4), dtype=torch.int32, device=DEVICE), + "cu_num_draft_tokens_ptr": torch.tensor([2, 1], dtype=torch.int32, device=DEVICE), + # ... other fields +} + +# Save a CPU copy of the inputs +input_data_before = { + k: (v.clone().cpu() if isinstance(v, torch.Tensor) else copy.deepcopy(v)) + for k, v in input_data.items() +} +# Preprocess input_data_before to match the NPU kernel inputs +input_data_before["npu_need_param_key"] = NPU_NEED_PARAMS_VALUE +# Run the kernel (on the GPU / reference implementation) and collect the outputs +triton_kernel[grid](**input_data) +# input_data is used here as an example; in practice call the corresponding triton/pytorch function +gpu_output = {k: (v.cpu() if isinstance(v, torch.Tensor) else v) for k, v in input_data.items()} + +save_obj = {"input_data": input_data_before, "grid": grid, "gpu_output": gpu_output} +torch.save(save_obj, ".pt") +# Multiple cases: torch.save([save_obj1, save_obj2], ".pt") +``` + +## 2. Add a third-party kernel test case in triton-ascend + +- **Step 1 - Add the kernel operator to the triton-ascend repository**: for local verification, add a Python file named after the operator under kernels/xxx (e.g. vllm or sglang) whose content is the Triton kernel function. +- **Step 2 - Test locally**: put the pt file in the kernels directory and run the following from the project root +python -m pytest -v third_party/ascend/unittest/kernels/test_triton_kernel.py + +**Notes** +- To run a single kernel: from the project root, execute python -m pytest -v third_party/ascend/unittest/kernels/test_triton_kernel.py --kernel={kernel_name} +- pt file lookup strategy: a matching local pt file in the repository is used first; if none exists locally, the {kernel_name}.pt file is downloaded on demand from the remote OBS bucket. +- pt files that already exist locally are not deleted after the tests finish; files fetched from the OBS bucket are deleted by the test program once the tests complete. + +## 3. Upload the pt file to the OBS bucket +After local verification passes, upload the pt files to the OBS bucket at https://triton-ascend-artifacts.obs.cn-southwest-2.myhuaweicloud.com/test/kernels/{xxx}_pt/{kernel_name}.pt, where xxx is vllm or sglang diff --git a/third_party/ascend/unittest/kernels/common_kernel.py b/third_party/ascend/unittest/kernels/common_kernel.py new file mode 100644 index 000000000..fbbce42dd --- /dev/null +++ b/third_party/ascend/unittest/kernels/common_kernel.py @@ -0,0 +1,7 @@ +import triton +import triton.language as tl + + +@triton.jit +def safe_exp(x): + return tl.exp(tl.where(x <= 0, x, float("-inf"))) diff --git a/third_party/ascend/unittest/kernels/test_common.py b/third_party/ascend/unittest/kernels/test_common.py new file mode 100644 index 000000000..ababcb154 --- /dev/null +++ b/third_party/ascend/unittest/kernels/test_common.py @@ -0,0 +1,111 @@ +from typing import Optional +import torch +import pytest + +DEVICE_TYPE_NPU = 'npu' + + +def validate_cmp(dtype, y_cal, y_ref, overflow_mode: Optional[str] = None, device_type: Optional[str] = None): + if device_type is not None: + target_device = torch.device(device_type) + y_cal = y_cal.to(target_device) + y_ref = y_ref.to(target_device) + else: + y_cal = y_cal.npu() + y_ref = y_ref.npu() + if overflow_mode == "saturate": + if dtype in ['float32', 'float16']: + min_value = -torch.finfo(dtype).min + max_value = torch.finfo(dtype).max + elif dtype in ['int32', 'int16', 'int8']: + min_value = torch.iinfo(dtype).min + max_value = torch.iinfo(dtype).max + elif dtype == 'bool': + min_value = 0 + max_value = 1 + else: + raise ValueError('Invalid parameter "dtype" is found : {}'.format(dtype)) + y_ref = torch.clamp(y_ref, min=min_value, max=max_value) + if dtype == 'float16': + torch.testing.assert_close(y_ref, y_cal, rtol=5e-03, atol=5e-03, equal_nan=True) + elif dtype == 'bfloat16': + torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=5e-03, atol=5e-03, + equal_nan=True) + elif dtype == 'float32': +
torch.testing.assert_close(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True) + elif dtype in ['int64', 'int32', 'int16', 'int8']: + assert torch.equal(y_cal, y_ref) + elif dtype == 'bool': + assert torch.equal(y_cal, y_ref) + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) + + +def convert_tensor_with_device_type(indata: dict, device_type: str): + target_device = torch.device(device_type) + outdata = {} + + for key, value in indata.items(): + if isinstance(value, torch.Tensor): + if value.device.type != target_device.type: + outdata[key] = value.to(target_device) + else: + outdata[key] = value + else: + outdata[key] = value + + return outdata + + +def compare_data_precision(dict_ref: dict, dict_cal: dict, device_type: str): + keys_ref, keys_cal = set(dict_ref.keys()), set(dict_cal.keys()) + if not keys_ref.issubset(keys_cal): + raise ValueError("The keys of dict_ref is not subset of dict_cal") + + for key in dict_ref.keys(): + val_a, val_b = dict_ref[key], dict_cal[key] + if not isinstance(val_b, type(val_a)): + raise ValueError("The data type of two dicts are different") + + if isinstance(val_a, torch.Tensor): + validate_cmp(dtype=str(val_a.dtype).split('.')[-1], y_ref=val_a, y_cal=val_b, device_type=device_type) + + +def run_and_compare_ptfile(ptfile_path: str, kernel_runner, device_type: str = DEVICE_TYPE_NPU): + try: + datas = torch.load(ptfile_path, map_location=torch.device('cpu')) + except Exception as e: + pytest.fail(f"load file {ptfile_path} failed: {e}") + + def _run_single_case(data): + if not isinstance(data, dict): + pytest.fail("Each case loaded from pt file must be a dict") + + input_data = convert_tensor_with_device_type(data.get("input_data", {}), device_type=device_type) + grid = data.get("grid") + try: + kernel_runner(input_data, grid) + except Exception as e: + pytest.fail(f"kernel_runner execution failed: {e}") + + output_data_cpu = convert_tensor_with_device_type(input_data, device_type='cpu') + expected = data.get("gpu_output", {}) + expected_filtered = {k: expected[k] for k in output_data_cpu.keys() if k in expected} + if not expected_filtered: + pytest.fail("No matching expected outputs found in pt file for comparison") + try: + compare_data_precision(expected_filtered, output_data_cpu, device_type='cpu') + except Exception as e: + pytest.fail(f"The testcase failed: {e}") + + # Supports three scenarios: + # 1) The file stores a single dict (existing behavior) + # 2) The file stores a list, where each element is a case dict + # 3) The file stores a dict, but some tensors represent multiple cases in batch on the 0th dimension (no automatic splitting; it is recommended to use a list) + if isinstance(datas, list): + for _, data in enumerate(datas): + _run_single_case(data) + elif isinstance(datas, dict): + _run_single_case(datas) + else: + pytest.fail("Unsupported pt file format: must be a dict or a list of dicts") diff --git a/third_party/ascend/unittest/kernels/test_triton_kernel.py b/third_party/ascend/unittest/kernels/test_triton_kernel.py new file mode 100644 index 000000000..528c8eb08 --- /dev/null +++ b/third_party/ascend/unittest/kernels/test_triton_kernel.py @@ -0,0 +1,73 @@ +import importlib +import os +import urllib.request +from pathlib import Path + +import pytest + +import test_common + + +def discover_kernels(): + kernels = [] + kernels_root_path = Path(__file__).parents[0] + for p in kernels_root_path.rglob("*.py"): + if not p.is_file(): + continue + if p.parent == kernels_root_path: + continue + rel = 
p.relative_to(kernels_root_path) + if len(rel.parts) == 1 or p.name == "__init__.py": + continue + module_path = ".".join(rel.with_suffix("").parts) + kernels.append((module_path, p.stem)) + return sorted(kernels, key=lambda x: x[1]) + + +KERNEL_ITEMS = discover_kernels() + + +@pytest.mark.parametrize("module_path, kernel_name", KERNEL_ITEMS) +def test_triton_kernel(module_path, kernel_name, pytestconfig): + selected = pytestconfig.getoption("kernel") + if selected: + if kernel_name not in selected: + pytest.skip(f"skip {kernel_name} due to --kernel filter") + base_url = "https://triton-ascend-artifacts.obs.cn-southwest-2.myhuaweicloud.com" + rel = module_path + parts = rel.split(".") if rel else [] + pt_url = f"{base_url}/test/kernels/{parts[0]}_pt/{kernel_name}.pt" + local_pt = Path(__file__).parent / f"{kernel_name}.pt" + downloaded = False + if not local_pt.exists(): + try: + urllib.request.urlretrieve(pt_url, local_pt) + downloaded = True + except Exception as e: + pytest.fail( + f"Failed to download the {kernel_name}.pt file. Please check whether the {kernel_name}.pt file has been uploaded to the OBS bucket: {e}" + ) + try: + mod = importlib.import_module(module_path) + except Exception as e: + pytest.fail(f"import {module_path} failed: {e}") + + if hasattr(mod, kernel_name): + kernel_attr = kernel_name + else: + candidates = [a for a in dir(mod) if a.endswith("_kernel")] + kernel_attr = candidates[0] if candidates else None + + if not kernel_attr: + pytest.fail(f"No kernel callable found in {module_path}") + + kernel_callable = getattr(mod, kernel_attr) + + def runner(input_data, grid): + kernel_callable[grid](**input_data) + + try: + test_common.run_and_compare_ptfile(str(local_pt), runner, device_type='npu') + finally: + if downloaded and local_pt.exists(): + local_pt.unlink() diff --git a/third_party/ascend/unittest/kernels/vllm/expand_kernel.py b/third_party/ascend/unittest/kernels/vllm/expand_kernel.py new file mode 100644 index 000000000..8c87b6d0b --- /dev/null +++ b/third_party/ascend/unittest/kernels/vllm/expand_kernel.py @@ -0,0 +1,33 @@ +import triton +import triton.language as tl +import triton.language.extra.cann.extension as extension + + +@triton.jit(do_not_specialize=["replace_from", "replace_to"]) +def expand_kernel( + output_ptr, # [num_tokens] + input_ptr, # [batch_size] + cu_num_tokens_ptr, # [batch_size] + replace_from, + replace_to, + vec_len, + MAX_NUM_TOKENS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + len_mask = offset < vec_len + + start_idx = tl.where(offset == 0, 0, tl.load(cu_num_tokens_ptr + offset - 1, len_mask)) + end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask) + num_tokens = end_idx - start_idx + + src_val = tl.load(input_ptr + offset, len_mask) + src_val = tl.where(src_val == replace_from, replace_to, src_val) + + for i in tl.range(0, BLOCK_SIZE): + num_tokens1 = extension.get_element(num_tokens, (i, )) + start_idx1 = extension.get_element(start_idx, (i, )) + src_val1 = extension.get_element(src_val, (i, )) + offset1 = tl.arange(0, MAX_NUM_TOKENS) + tl.store(output_ptr + start_idx1 + offset1, src_val1, mask=offset1 < num_tokens1) diff --git a/third_party/ascend/unittest/kernels/vllm/rejection_random_sample_kernel.py b/third_party/ascend/unittest/kernels/vllm/rejection_random_sample_kernel.py new file mode 100644 index 000000000..b5d124a91 --- /dev/null +++ b/third_party/ascend/unittest/kernels/vllm/rejection_random_sample_kernel.py @@ -0,0 +1,55 @@ 
+import triton +import triton.language as tl + + +@triton.jit(do_not_specialize=["max_spec_len"]) +def rejection_random_sample_kernel( + output_token_ids_ptr, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + bonus_token_ids_ptr, # [batch_size] + recovered_token_ids_ptr, # [num_tokens] + uniform_probs_ptr, # [num_tokens] + is_greedy_ptr, # [batch_size] + max_spec_len, + vocab_size, + NO_DRAFT_PROBS: tl.constexpr, +): + req_idx = tl.program_id(0) + is_greedy = tl.load(is_greedy_ptr + req_idx) + if is_greedy: + # Early exit for greedy sampling requests + return + + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + rejected = False + for pos in range(num_draft_tokens): + if not rejected: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + if NO_DRAFT_PROBS: + draft_prob = 1 + else: + draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) + target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) + uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos) + if draft_prob > 0 and target_prob / draft_prob >= uniform_prob: + # Accept + token_id = draft_token_id + else: + # Reject. Use the recovered token instead + rejected = True + token_id = tl.load(recovered_token_ids_ptr + start_idx + pos) + tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id) + + if not rejected: + # If all tokens are accepted, append the bonus token + bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) + tl.store( + output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, + bonus_token_id, + ) diff --git a/third_party/ascend/unittest/kernels/vllm/sample_recovered_tokens_kernel.py b/third_party/ascend/unittest/kernels/vllm/sample_recovered_tokens_kernel.py new file mode 100644 index 000000000..24aa9c7b7 --- /dev/null +++ b/third_party/ascend/unittest/kernels/vllm/sample_recovered_tokens_kernel.py @@ -0,0 +1,77 @@ +import triton +import triton.language as tl +import triton.language.extra.cann.extension as extension + + +@triton.jit +def sample_recovered_tokens_kernel( + output_token_ids_ptr, # [num_tokens] + cu_num_draft_tokens_ptr, # [batch_size] + draft_token_ids_ptr, # [num_tokens] + draft_probs_ptr, # [num_tokens, vocab_size] or None + target_probs_ptr, # [num_tokens, vocab_size] + q_ptr, # [batch_size, vocab_size] + vocab_size, + PADDED_VOCAB_SIZE: tl.constexpr, + NO_DRAFT_PROBS: tl.constexpr, + SUB_BLOCK: tl.constexpr, +): + req_idx = tl.program_id(0) + start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1) + end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) + num_draft_tokens = end_idx - start_idx + + # Early exit for out-of-range positions. + pos = tl.program_id(1) + if pos >= num_draft_tokens: + return + + loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK + global_recovered_id = -1 + global_max_p = -1.0 + if NO_DRAFT_PROBS: + draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) + orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) + # Temporarily zero out the probability of the draft token. + # This is essentially the same as target_prob - draft_prob, except that + # n-gram does not have draft_prob. We regard it as 1.
+ tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, 0) + for loop_i in range(loop): + vocab_start = loop_i * SUB_BLOCK + vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK) + prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset + < vocab_size, other=0) + q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, + other=float("-inf")) + new_p = prob / q + recovered_id = tl.argmax(new_p, axis=-1) + max_p = extension.get_element(new_p, (recovered_id, )) + if max_p > global_max_p: + global_max_p = max_p + global_recovered_id = vocab_start + recovered_id + else: + for loop_i in range(loop): + vocab_start = loop_i * SUB_BLOCK + vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK) + draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset + < vocab_size, other=0) + target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, mask=vocab_offset + < vocab_size, other=0) + prob = tl.maximum(target_prob - draft_prob, 0) + # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because + # `tl.argmax` will select the maximum value. + + q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, + other=float("-inf")) + new_p = prob / q + recovered_id = tl.argmax(new_p, axis=-1) + max_p = extension.get_element(new_p, (recovered_id, )) + if max_p > global_max_p: + global_max_p = max_p + global_recovered_id = vocab_start + recovered_id + + tl.store(output_token_ids_ptr + start_idx + pos, global_recovered_id) + + if NO_DRAFT_PROBS: + # Restore the original probability. + tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob) diff --git a/third_party/ascend/examples/pytest_ut/attn_cp_triton_kernel_3d.py b/third_party/ascend/unittest/pytest_ut/attn_cp_triton_kernel_3d.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/attn_cp_triton_kernel_3d.py rename to third_party/ascend/unittest/pytest_ut/attn_cp_triton_kernel_3d.py diff --git a/third_party/ascend/examples/pytest_ut/attn_cp_triton_kernel_3d_la.py b/third_party/ascend/unittest/pytest_ut/attn_cp_triton_kernel_3d_la.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/attn_cp_triton_kernel_3d_la.py rename to third_party/ascend/unittest/pytest_ut/attn_cp_triton_kernel_3d_la.py diff --git a/third_party/ascend/unittest/pytest_ut/conftest.py b/third_party/ascend/unittest/pytest_ut/conftest.py new file mode 100644 index 000000000..7c76d57ce --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/conftest.py @@ -0,0 +1,48 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import pytest + + +@pytest.fixture(scope="module", autouse=True) +def assign_npu(request, worker_id): + marker = request.node.get_closest_marker("backend") + if marker: + backend = marker.args[0] + else: + backend = "torch_npu" + if backend == "torch_npu": + import torch + npu_count = torch.npu.device_count() + if worker_id == "master": + npu_id = 0 + else: + idx = int(worker_id.replace("gw", "")) + npu_id = idx % npu_count + torch.npu.set_device(npu_id) + elif backend == "mindspore": + import mindspore + npu_count = mindspore.device_context.ascend.device_count() + if worker_id == "master": + npu_id = 0 + else: + idx = int(worker_id.replace("gw", "")) + npu_id = idx % npu_count + mindspore.set_device("Ascend", npu_id) diff --git a/third_party/ascend/examples/pytest_ut/test_2d_permute.py b/third_party/ascend/unittest/pytest_ut/test_2d_permute.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_2d_permute.py rename to third_party/ascend/unittest/pytest_ut/test_2d_permute.py diff --git a/third_party/ascend/examples/pytest_ut/test_3Dgrid.py b/third_party/ascend/unittest/pytest_ut/test_3Dgrid.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_3Dgrid.py rename to third_party/ascend/unittest/pytest_ut/test_3Dgrid.py diff --git a/third_party/ascend/examples/pytest_ut/test_abs.py b/third_party/ascend/unittest/pytest_ut/test_abs.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_abs.py rename to third_party/ascend/unittest/pytest_ut/test_abs.py diff --git a/third_party/ascend/examples/pytest_ut/test_abs_2.py b/third_party/ascend/unittest/pytest_ut/test_abs_2.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_abs_2.py rename to third_party/ascend/unittest/pytest_ut/test_abs_2.py diff --git a/third_party/ascend/examples/pytest_ut/test_acos.py b/third_party/ascend/unittest/pytest_ut/test_acos.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_acos.py rename to third_party/ascend/unittest/pytest_ut/test_acos.py index bdbee9102..18c0e6bcb 100644 --- a/third_party/ascend/examples/pytest_ut/test_acos.py +++ b/third_party/ascend/unittest/pytest_ut/test_acos.py @@ -22,7 +22,7 @@ import triton import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice import test_common import torch diff --git a/third_party/ascend/examples/pytest_ut/test_acosh.py b/third_party/ascend/unittest/pytest_ut/test_acosh.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_acosh.py rename to third_party/ascend/unittest/pytest_ut/test_acosh.py index 5349b4599..78705b33b 100644 --- a/third_party/ascend/examples/pytest_ut/test_acosh.py +++ b/third_party/ascend/unittest/pytest_ut/test_acosh.py @@ -22,7 +22,7 @@ import triton import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice import test_common import torch diff 
--git a/third_party/ascend/examples/pytest_ut/test_add.py b/third_party/ascend/unittest/pytest_ut/test_add.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_add.py rename to third_party/ascend/unittest/pytest_ut/test_add.py diff --git a/third_party/ascend/examples/autotune_cases/01-vector-add.py b/third_party/ascend/unittest/pytest_ut/test_add_mindspore.py similarity index 70% rename from third_party/ascend/examples/autotune_cases/01-vector-add.py rename to third_party/ascend/unittest/pytest_ut/test_add_mindspore.py index bff76e891..5d5a1638f 100644 --- a/third_party/ascend/examples/autotune_cases/01-vector-add.py +++ b/third_party/ascend/unittest/pytest_ut/test_add_mindspore.py @@ -17,23 +17,17 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -""" -Vector Add -============= -""" import os - -import torch -import torch_npu import triton import triton.language as tl -from triton.testing import do_bench_npu +import numpy as np +import mindspore +import pytest + +pytestmark = pytest.mark.backend("mindspore") -# split_params={"x": "BLOCK_SIZE"}, tiling_params={}, low_dims=["x"] -# persistent_reduction=False, dual_reduction=False -@triton.autotune(configs=[], hints={"enable_ascend_autotune": True}, key=["n_elements"]) @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. y_ptr, # *Pointer* to second input vector. @@ -62,34 +56,29 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. tl.store(output_ptr + offsets, output, mask=mask) -def add_torch(x, y): - return x + y - - -def add_autotune(x, y): - output = torch.empty_like(x) +def add(x: mindspore.Tensor, y: mindspore.Tensor): + output = mindspore.mint.empty_like(x) n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), ) - add_kernel[grid](x, y, output, n_elements) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) return output -def test_add(size: int): - os.environ["TRITON_BENCH_METHOD"] = ( - "npu" # use torch_npu.profiler to get calculating time - ) - x = torch.rand(size, device="npu") - y = torch.rand(size, device="npu") - - output_torch = add_torch(x, y) - output_triton = add_autotune(x, y) - assert torch.allclose(output_triton, output_torch) - - time_eager = do_bench_npu(lambda: add_torch(x, y)) - time_triton = do_bench_npu(lambda: add_autotune(x, y)) - assert (time_eager / time_triton) >= 0.8 - print(f"Vector Add {size} PASSED!") +def add_mindspore(x, y): + return x + y -if __name__ == "__main__": - test_add(98432) +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8)], + ['float16', (2, 4096, 8)], +]) +def test_add_mindspore(param_list): + os.environ["TRITON_BACKEND"] = "mindspore" + dtype, shape = param_list + mindspore.set_seed(0) + x = mindspore.ops.randn(shape, dtype=eval('mindspore.' + dtype)) + y = mindspore.ops.randn(shape, dtype=eval('mindspore.' 
+ dtype)) + output_triton = add(x, y) + output_mindspore = add_mindspore(x, y) + assert np.allclose(output_triton.asnumpy(), output_mindspore.asnumpy(), rtol=1e-3, atol=1e-3) + del os.environ["TRITON_BACKEND"] diff --git a/third_party/ascend/examples/pytest_ut/test_advance.py b/third_party/ascend/unittest/pytest_ut/test_advance.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_advance.py rename to third_party/ascend/unittest/pytest_ut/test_advance.py diff --git a/third_party/ascend/unittest/pytest_ut/test_alloc.py b/third_party/ascend/unittest/pytest_ut/test_alloc.py new file mode 100644 index 000000000..a1b2b4360 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_alloc.py @@ -0,0 +1,79 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import os + +import pytest +import triton +import triton.language as tl +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +import triton.extension.buffer.language as bl +import triton.language.extra.cann.extension as al +from triton._C.libtriton import ir, buffer_ir +from triton._C.libtriton.ascend import ir as ascend_ir + +os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + + +class Options: + num_warps = 4 + num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + buffer_ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {"create_address_space": al.semantic.create_address_space}, + {}) + return str(module) + + +# ============== Kernel definitions ============== + + +@triton.jit +def allocate_local_buffer(XBLOCK: tl.constexpr): + # These statements have no effect; they only exercise the builder + bl.alloc(tl.float32, [XBLOCK]) + bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.UB) + bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L1) + bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0A) + bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0B) + bl.alloc(tl.float32, [XBLOCK, XBLOCK], al.ascend_address_space.L0C) + + +# ============== Main for manual testing ============== + +if __name__ == "__main__": + print("=" * 60) + print("Test 1: Allocate Local Buffer") + print("=" * 60) + mlir = compile_kernel(allocate_local_buffer, {}, {"XBLOCK": 256}) + print(f"✅ Generated MLIR ({len(mlir)} chars):\n") + print(mlir) diff --git a/third_party/ascend/examples/pytest_ut/test_and.py b/third_party/ascend/unittest/pytest_ut/test_and.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_and.py rename to third_party/ascend/unittest/pytest_ut/test_and.py diff --git a/third_party/ascend/unittest/pytest_ut/test_annotations.py b/third_party/ascend/unittest/pytest_ut/test_annotations.py new file mode 100644 index 000000000..d5fb421ee --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_annotations.py @@ -0,0 +1,71 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright 2018-2020 Philippe Tillet +# Copyright 2020-2022 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +from __future__ import annotations +import torch +import triton +import triton.language as tl +import pytest + + +def annotated_function(return_type=None, **arg_types): + """A decorator to add annotations to a function.""" + + def decorator(func): + func.__annotations__ = {**arg_types, 'return': return_type} + return func + + return decorator + + +# Test integer annotations +@pytest.mark.parametrize(("signed", "width"), [ + (signed, width) for signed in [False, True]\ + for width in [8, 16, 32, 64] +] + [(False, 1)] + ) +def test_int_annotation(signed, width, device="npu"): + + @triton.jit + @annotated_function(X=torch.tensor, v=f"tl.{'' if signed else 'u'}int{width}") + def _kernel(X, v): + tl.store(X, v) + + h = _kernel[(1, )](torch.empty(1, device=device), 3) + pfx = 'si' if signed else 'ui' + assert f'%arg1: i{width}' in h.asm["ttir"] + assert f'arith.{pfx}tofp' in h.asm["ttir"] + + +# Test that unknown annotations do not emit an error +def test_unknown_annotation(device="npu"): + + @triton.jit + def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr): + pass + + x = torch.empty(1, device=device) + _kernel[(1, )](x, x.shape[0], 32) + try: + _kernel[(1, )](x.shape[0], x.shape[0], 32) + except AttributeError: + pass diff --git a/third_party/ascend/examples/pytest_ut/test_arange.py b/third_party/ascend/unittest/pytest_ut/test_arange.py similarity index 80% rename from third_party/ascend/examples/pytest_ut/test_arange.py rename to third_party/ascend/unittest/pytest_ut/test_arange.py index f8cceae49..5c049bec8 100644 --- a/third_party/ascend/examples/pytest_ut/test_arange.py +++ b/third_party/ascend/unittest/pytest_ut/test_arange.py @@ -98,30 +98,16 @@ def test_case_access(param_list): @pytest.mark.parametrize('invalid_param_list', [ [0, 10000000], + [1024, 128], ]) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, - "end - start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = 1048576") def test_arange_invalid_range(invalid_param_list): start, end = invalid_param_list shape = [end - start] block = end - start - - y_cal = torch.zeros(shape, dtype=torch.int32).npu() - - triton_arange[(1, )](y_cal, START=start, END=end, BLOCK=block) - - -@pytest.mark.parametrize('invalid_param_list', [ - [1024, 128], -]) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, - "arange's end argument must be greater than the start argument") -def test_arange_invalid_revinput(invalid_param_list): - start, end = invalid_param_list - range = abs(end - start) - shape = [range] - block = range - - y_cal = torch.zeros(shape, dtype=torch.int32).npu() - - triton_arange[(1, )](y_cal, START=start, END=end, BLOCK=block) + flag = False + try: + y_cal = torch.zeros(shape, dtype=torch.int32).npu() + triton_arange[(1, )](y_cal, START=start, END=end, BLOCK=block) + except Exception as e: + flag = True + assert flag diff --git a/third_party/ascend/unittest/pytest_ut/test_ascend_barrier.py b/third_party/ascend/unittest/pytest_ut/test_ascend_barrier.py new file mode 100755 index 000000000..78573ac16 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_ascend_barrier.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +import os + +import pytest +import triton +import triton.language as tl +import 
triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + +os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + + +class Options: + num_warps = 4 + num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {}, {}) + return str(module) + + +# ============== Kernel definitions ============== + + +@triton.jit +def kernel_debug_barrier(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + """Test debug barrier.""" + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + with al.scope(core_mode="vector"): + al.debug_barrier(al.SYNC_IN_VF.VV_ALL) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + result = x + y + tl.store(out_ptr + i, result, mask=i < n) + + +# ============== Pytest tests ============== + + +def test_debug_barrier(): + """Test debug barrier generates.""" + mlir = compile_kernel( + kernel_debug_barrier, + {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + assert "annotation.mark" in mlir and "VV_ALL" in mlir + + +# ============== Main for manual testing ============== + +if __name__ == "__main__": + print("=" * 60) + print("Test: debug barrier") + print("=" * 60) + mlir = compile_kernel( + kernel_debug_barrier, + {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + print(f"✅ Generated MLIR ({len(mlir)} chars):\n") + print(mlir) diff --git a/third_party/ascend/examples/pytest_ut/test_asin.py b/third_party/ascend/unittest/pytest_ut/test_asin.py similarity index 95% rename from third_party/ascend/examples/pytest_ut/test_asin.py rename to third_party/ascend/unittest/pytest_ut/test_asin.py index 9beadf733..8d9d7249e 100644 --- a/third_party/ascend/examples/pytest_ut/test_asin.py +++ b/third_party/ascend/unittest/pytest_ut/test_asin.py @@ -1,90 +1,90 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import triton -import torch -import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice -import test_common - - -@triton.jit -def asin_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - mask = offsets < n_elements - - x = tl.load(x_ptr + offsets, mask=mask) - - y = libdevice.asin(x) - - tl.store(y_ptr + offsets, y, mask=mask) - - -@pytest.mark.parametrize('shape', [ - (12, 16), -]) -@pytest.mark.parametrize('dtype', ['float32']) -def test_asin(shape, dtype): - n_elements = shape[0] * shape[1] - - x = test_common.generate_tensor(shape, dtype).npu() - - # Ensure to include some boundary cases - x[0, 0] = 0.0 - x[0, 1] = 0.5 - x[0, 2] = -0.5 - x[0, 3] = 1.0 - x[0, 4] = -1.0 - x[0, 5] = 0.707 # sin(π/4) - x[0, 6] = 0.866 # sin(π/3) - - # Add some out-of-range values - x[0, 7] = 1.1 - x[0, 8] = -1.1 - - y = torch.empty_like(x) - - BLOCK_SIZE = 192 - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) - - asin_kernel[grid](x, y, n_elements, BLOCK_SIZE=BLOCK_SIZE) - - expected = torch.asin(x) - - # Check the accuracy for values within the effective range. - valid_mask = (x >= -1) & (x <= 1) - - if torch.any(valid_mask): - valid_y = y[valid_mask] - valid_expected = expected[valid_mask] - - torch.testing.assert_close(valid_y, valid_expected, rtol=1e-3, atol=1e-3) - - # Check if values outside the range return NaN - invalid_mask = (x < -1) | (x > 1) - if torch.any(invalid_mask): - invalid_y = y[invalid_mask] - assert torch.all(torch.isnan(invalid_y)), "Invalid inputs should return NaN" - - print("✓ ASIN test PASSED!") +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import pytest +import triton +import torch +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +import test_common + + +@triton.jit +def asin_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + + y = libdevice.asin(x) + + tl.store(y_ptr + offsets, y, mask=mask) + + +@pytest.mark.parametrize('shape', [ + (12, 16), +]) +@pytest.mark.parametrize('dtype', ['float32']) +def test_asin(shape, dtype): + n_elements = shape[0] * shape[1] + + x = test_common.generate_tensor(shape, dtype).npu() + + # Ensure to include some boundary cases + x[0, 0] = 0.0 + x[0, 1] = 0.5 + x[0, 2] = -0.5 + x[0, 3] = 1.0 + x[0, 4] = -1.0 + x[0, 5] = 0.707 # sin(π/4) + x[0, 6] = 0.866 # sin(π/3) + + # Add some out-of-range values + x[0, 7] = 1.1 + x[0, 8] = -1.1 + + y = torch.empty_like(x) + + BLOCK_SIZE = 192 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + asin_kernel[grid](x, y, n_elements, BLOCK_SIZE=BLOCK_SIZE) + + expected = torch.asin(x) + + # Check the accuracy for values within the effective range. + valid_mask = (x >= -1) & (x <= 1) + + if torch.any(valid_mask): + valid_y = y[valid_mask] + valid_expected = expected[valid_mask] + + torch.testing.assert_close(valid_y, valid_expected, rtol=1e-3, atol=1e-3) + + # Check if values outside the range return NaN + invalid_mask = (x < -1) | (x > 1) + if torch.any(invalid_mask): + invalid_y = y[invalid_mask] + assert torch.all(torch.isnan(invalid_y)), "Invalid inputs should return NaN" + + print("✓ ASIN test PASSED!") diff --git a/third_party/ascend/examples/pytest_ut/test_asinh.py b/third_party/ascend/unittest/pytest_ut/test_asinh.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_asinh.py rename to third_party/ascend/unittest/pytest_ut/test_asinh.py index 3f7809808..38865e92b 100644 --- a/third_party/ascend/examples/pytest_ut/test_asinh.py +++ b/third_party/ascend/unittest/pytest_ut/test_asinh.py @@ -22,7 +22,7 @@ import triton import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice import test_common import torch diff --git a/third_party/ascend/examples/pytest_ut/test_asm.py b/third_party/ascend/unittest/pytest_ut/test_asm.py similarity index 96% rename from third_party/ascend/examples/pytest_ut/test_asm.py rename to third_party/ascend/unittest/pytest_ut/test_asm.py index da9898b97..02e69bddd 100644 --- a/third_party/ascend/examples/pytest_ut/test_asm.py +++ b/third_party/ascend/unittest/pytest_ut/test_asm.py @@ -1,52 +1,52 @@ -import triton -import triton.language as tl -import numpy as np -import torch -import pytest -import test_common - - -def torch_add(x, y): - res = x + y - return res - - -@triton.jit -def triton_asm_add( - x_ptr, - y_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = tl.inline_asm_elementwise( - asm=""" - ADD.s64 $0, $1, $2 - """, - constraints=("=l,l,l"), - args=[x, y], - dtype=tl.int64, - is_pure=True, - pack=1, - ) - tl.store(output_ptr + offsets, output, mask=mask) - - -@pytest.mark.parametrize('param_list', [ - 
['int64', 4096, 1024], -]) -def test_case(param_list): - dtype, length, block_size = param_list - ncore = length // block_size - x = test_common.generate_tensor((length, ), dtype).npu() - y = test_common.generate_tensor((length, ), dtype).npu() - res_ref = torch_add(x, y) - res_cal = torch.zeros((length, ), dtype=eval('torch.' + dtype)).npu() - triton_asm_add[(ncore, )](x, y, res_cal, length, BLOCK_SIZE=block_size) - test_common.validate_cmp(dtype, res_cal, res_ref) +import triton +import triton.language as tl +import numpy as np +import torch +import pytest +import test_common + + +def torch_add(x, y): + res = x + y + return res + + +@triton.jit +def triton_asm_add( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = tl.inline_asm_elementwise( + asm=""" + ADD.s64 $0, $1, $2 + """, + constraints=("=l,l,l"), + args=[x, y], + dtype=tl.int64, + is_pure=True, + pack=1, + ) + tl.store(output_ptr + offsets, output, mask=mask) + + +@pytest.mark.parametrize('param_list', [ + ['int64', 4096, 1024], +]) +def test_case(param_list): + dtype, length, block_size = param_list + ncore = length // block_size + x = test_common.generate_tensor((length, ), dtype).npu() + y = test_common.generate_tensor((length, ), dtype).npu() + res_ref = torch_add(x, y) + res_cal = torch.zeros((length, ), dtype=eval('torch.' + dtype)).npu() + triton_asm_add[(ncore, )](x, y, res_cal, length, BLOCK_SIZE=block_size) + test_common.validate_cmp(dtype, res_cal, res_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_associative_scan.py b/third_party/ascend/unittest/pytest_ut/test_associative_scan.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_associative_scan.py rename to third_party/ascend/unittest/pytest_ut/test_associative_scan.py diff --git a/third_party/ascend/examples/pytest_ut/test_associative_scan_multi_input.py b/third_party/ascend/unittest/pytest_ut/test_associative_scan_multi_input.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_associative_scan_multi_input.py rename to third_party/ascend/unittest/pytest_ut/test_associative_scan_multi_input.py diff --git a/third_party/ascend/unittest/pytest_ut/test_assume.py b/third_party/ascend/unittest/pytest_ut/test_assume.py new file mode 100644 index 000000000..b18c0f2db --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_assume.py @@ -0,0 +1,51 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import triton +import triton.language as tl +import pytest +import test_common + + +@triton.jit +def triton_assume(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + tl.assume((XBLOCK & (XBLOCK - 1)) == 0) + for xoffset_sub in range(0, XBLOCK, XBLOCK_SUB): + xindex = xoffset + xoffset_sub + tl.arange(0, XBLOCK_SUB)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(in_ptr0 + (x0), xmask) + num = tl.sum(tmp0) + tl.assume(num > 0) + tmp2 = tmp0 + tl.store(out_ptr0 + (xindex), tmp2, xmask) + + +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8), 2, 32768, 1024], +]) +def test_assume(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + y_ref = x0 + y_cal = test_common.generate_tensor(shape, dtype).npu() + triton_assume[(ncore, )](x0, y_cal, x0.numel(), xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_atan.py b/third_party/ascend/unittest/pytest_ut/test_atan.py similarity index 98% rename from third_party/ascend/examples/pytest_ut/test_atan.py rename to third_party/ascend/unittest/pytest_ut/test_atan.py index cc95df770..eca6b31f0 100644 --- a/third_party/ascend/examples/pytest_ut/test_atan.py +++ b/third_party/ascend/unittest/pytest_ut/test_atan.py @@ -27,7 +27,7 @@ import torch import torch_npu -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice def standard_unary(x0, dtype): diff --git a/third_party/ascend/examples/pytest_ut/test_atan2.py b/third_party/ascend/unittest/pytest_ut/test_atan2.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_atan2.py rename to third_party/ascend/unittest/pytest_ut/test_atan2.py index 74be45616..121866037 100644 --- a/third_party/ascend/examples/pytest_ut/test_atan2.py +++ b/third_party/ascend/unittest/pytest_ut/test_atan2.py @@ -27,7 +27,7 @@ import torch import torch_npu -from triton.language import math +import triton.language.extra.cann.libdevice as libdevice def standard_unary(x0, y0, dtype): @@ -45,7 +45,7 @@ def triton_elementwise_unary(in_ptr0, in_ptr1, out_ptr0, N: tl.constexpr, NUMEL: idx_block = tl.arange(0, NUMEL) x = tl.load(in_ptr0 + idx_block, mask=idx_block < N) y = tl.load(in_ptr1 + idx_block, mask=idx_block < N) - ret = math.atan2(y, x) + ret = libdevice.atan2(y, x) tl.store(out_ptr0 + idx_block, ret, mask=idx_block < N) diff --git a/third_party/ascend/examples/pytest_ut/test_atanh.py b/third_party/ascend/unittest/pytest_ut/test_atanh.py similarity index 98% rename from third_party/ascend/examples/pytest_ut/test_atanh.py rename to third_party/ascend/unittest/pytest_ut/test_atanh.py index 5e8126191..a8d8ad3fe 100644 --- a/third_party/ascend/examples/pytest_ut/test_atanh.py +++ b/third_party/ascend/unittest/pytest_ut/test_atanh.py @@ -22,7 +22,7 @@ import triton import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice import test_common import torch diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_add.py 
b/third_party/ascend/unittest/pytest_ut/test_atomic_add.py similarity index 85% rename from third_party/ascend/examples/pytest_ut/test_atomic_add.py rename to third_party/ascend/unittest/pytest_ut/test_atomic_add.py index 78f2f0923..10aa30520 100644 --- a/third_party/ascend/examples/pytest_ut/test_atomic_add.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_add.py @@ -77,6 +77,30 @@ def test_atomic_add(param_list): test_common.validate_cmp(dtype, x1, x1_ref) +@pytest.mark.parametrize('param_list', [ + ['int16', (32, 32), 1], + ['int32', (32, 32), 1], + ['float32', (32, 32), 1], + ['float16', (64, 64), 1], +]) +def test_atomic_add_return_value(param_list): + dtype, shape, ncore = param_list + block_size = shape[0] * shape[1] / ncore + split_size = shape[0] // ncore + x0_value = 3 + x0 = torch.full(shape, x0_value, dtype=eval(f'torch.{dtype}')).npu() + x1 = torch.full((split_size, shape[1]), 2, dtype=eval(f'torch.{dtype}')).npu() + y = torch.full((split_size, shape[1]), -10, dtype=eval(f'torch.{dtype}')).npu() + + y_ref = x1 + 0 + x1_ref = x1 + ncore * x0_value + + n_elements = shape[0] * shape[1] + atomic_add[ncore, 1, 1](x0, x1, y, n_elements, BLOCK_SIZE=split_size * shape[1]) + test_common.validate_cmp(dtype, x1, x1_ref) + test_common.validate_cmp(dtype, y, y_ref) + + @triton.jit def atomic_add_2d(in_ptr0, out_ptr0, out_ptr1, numel_0, numel_1, BLOCK_SIZE_0: tl.constexpr, BLOCK_SIZE_1: tl.constexpr): diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_and.py b/third_party/ascend/unittest/pytest_ut/test_atomic_and.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_atomic_and.py rename to third_party/ascend/unittest/pytest_ut/test_atomic_and.py diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_cas.py b/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py similarity index 68% rename from third_party/ascend/examples/pytest_ut/test_atomic_cas.py rename to third_party/ascend/unittest/pytest_ut/test_atomic_cas.py index 547aa8489..3e3b3a6fc 100644 --- a/third_party/ascend/examples/pytest_ut/test_atomic_cas.py +++ b/third_party/ascend/unittest/pytest_ut/test_atomic_cas.py @@ -75,3 +75,38 @@ def test_atomic_cas(param_list): n_elements = shape[0] * shape[1] atomic_cas[ncore, 1, 1](val, cmp, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) test_common.validate_cmp(dtype, pointer, pointer_ref) + + +@pytest.mark.parametrize('param_list', [ + ['int16', (8, 8), 1], + ['int32', (32, 32), 1], + ['float32', (32, 32), 1], +]) +def test_atomic_cas_return_value(param_list): + dtype, shape, ncore = param_list + block_size = shape[0] * shape[1] // ncore + split_size = shape[0] // ncore + + import random + cmp_val = [random.randint(0, 10) for _ in range(ncore)] + + cmp = torch.ones(split_size, shape[1], dtype=eval(f'torch.{dtype}')).to().npu() * cmp_val[0] + for i in range(1, ncore): + append = torch.ones(split_size, shape[1], dtype=eval(f'torch.{dtype}')).to().npu() * cmp_val[i] + cmp = torch.cat([cmp, append], dim=0) + + val = torch.randint(low=0, high=10, size=shape, dtype=eval(f'torch.{dtype}')).npu() + + pointer = torch.randint(low=0, high=10, size=(split_size, shape[1]), dtype=eval(f'torch.{dtype}')).npu() + pointer_old_ref = pointer.clone() + pointer_old = torch.full_like(pointer, -10).npu() + pointer_ref = pointer.clone() + + for i in range(ncore): + val_subview = val[(i * split_size):((i + 1) * split_size)] + pointer_ref = torch.where(pointer_ref == cmp_val[i], val_subview, pointer_ref) + + n_elements = shape[0] * 
shape[1] + atomic_cas[ncore, 1, 1](val, cmp, pointer, pointer_old, n_elements, BLOCK_SIZE=split_size * shape[1]) + test_common.validate_cmp(dtype, pointer, pointer_ref) + test_common.validate_cmp(dtype, pointer_old, pointer_old_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_max.py b/third_party/ascend/unittest/pytest_ut/test_atomic_max.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_atomic_max.py rename to third_party/ascend/unittest/pytest_ut/test_atomic_max.py diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_min.py b/third_party/ascend/unittest/pytest_ut/test_atomic_min.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_atomic_min.py rename to third_party/ascend/unittest/pytest_ut/test_atomic_min.py diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_or.py b/third_party/ascend/unittest/pytest_ut/test_atomic_or.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_atomic_or.py rename to third_party/ascend/unittest/pytest_ut/test_atomic_or.py diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_xchg.py b/third_party/ascend/unittest/pytest_ut/test_atomic_xchg.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_atomic_xchg.py rename to third_party/ascend/unittest/pytest_ut/test_atomic_xchg.py diff --git a/third_party/ascend/examples/pytest_ut/test_atomic_xor.py b/third_party/ascend/unittest/pytest_ut/test_atomic_xor.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_atomic_xor.py rename to third_party/ascend/unittest/pytest_ut/test_atomic_xor.py diff --git a/third_party/ascend/examples/pytest_ut/test_attn_cp.py b/third_party/ascend/unittest/pytest_ut/test_attn_cp.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_attn_cp.py rename to third_party/ascend/unittest/pytest_ut/test_attn_cp.py diff --git a/third_party/ascend/unittest/pytest_ut/test_bind_buffer.py b/third_party/ascend/unittest/pytest_ut/test_bind_buffer.py new file mode 100644 index 000000000..a8e59d076 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_bind_buffer.py @@ -0,0 +1,67 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +#!/usr/bin/env python3 +import os + +import triton +import triton.language as tl +import triton.extension.buffer.language as bl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + +os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + + +class Options: + num_warps = 4 + num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {}, {}) + return str(module) + + +@triton.jit +def bind_buffer(): + alloc = bl.alloc(tl.float32, [32, 32], al.ascend_address_space.UB) + tensor = tl.full((32, 32), 0, dtype=tl.float32) + bl.to_buffer(tensor, bind_buffer=alloc) + + +# ============== Main for manual testing ============== + +if __name__ == "__main__": + mlir = compile_kernel(bind_buffer, {}, {}) + assert len(mlir) > 0 + print(mlir) diff --git a/third_party/ascend/examples/pytest_ut/test_block_ptr.py b/third_party/ascend/unittest/pytest_ut/test_block_ptr.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_block_ptr.py rename to third_party/ascend/unittest/pytest_ut/test_block_ptr.py diff --git a/third_party/ascend/examples/pytest_ut/test_broadcast.py b/third_party/ascend/unittest/pytest_ut/test_broadcast.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_broadcast.py rename to third_party/ascend/unittest/pytest_ut/test_broadcast.py diff --git a/third_party/ascend/examples/pytest_ut/test_broadcast_op.py b/third_party/ascend/unittest/pytest_ut/test_broadcast_op.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_broadcast_op.py rename to third_party/ascend/unittest/pytest_ut/test_broadcast_op.py diff --git a/third_party/ascend/unittest/pytest_ut/test_cannonicalize_tl_where.py b/third_party/ascend/unittest/pytest_ut/test_cannonicalize_tl_where.py new file mode 100644 index 000000000..80a259d76 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_cannonicalize_tl_where.py @@ -0,0 +1,124 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +types_all = [ + (torch.float32, 'float32'), +] + +shapes_common = [(128, 256), (127, 256), (127, 16), (129, 256), (77, 1024), (69, 512), (512, 512)] + +block_size = [128, 256, 1024] + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def profiler_wrapper(fn, *args): + result_path = "./result_profiling_tl_where" + skip_first = 10 + wait = 0 + warmup = 3 + active = 30 + repeat = 1 + stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) + with torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], + schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, + skip_first=skip_first), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), record_shapes=True, + 
profile_memory=False, with_stack=False, with_flops=False, with_modules=False, + experimental_config=experimental_config) as prof: + stream.synchronize() + for _ in range(skip_first + (wait + warmup + active) * repeat): + fn(*args) + prof.step() + stream.synchronize() + + +@triton.jit +def tl_where_kernel( + in_ptr, + output_ptr, + N: tl.constexpr, + M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + offset = tl.multiple_of(pid * BLOCK_SIZE_N, N) + x1 = (offset + tl.arange(0, BLOCK_SIZE_N)) // N + mask1 = tl.where(x1 < M, 1, 0).to(tl.int1) + data = tl.load(in_ptr + x1 * N, mask=mask1, other=0) + x2 = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + tl.store(output_ptr + x2, data) + + +def torch_tl_where(in_tensor): + M = in_tensor.shape[0] // 2 + N = in_tensor.shape[1] + + output = torch.zeros_like(in_tensor) + + first_elements = in_tensor[:M, 0:1] + output[:M] = first_elements.expand(-1, N) + + return output + + +@pytest.mark.parametrize('dtype, sigtype', types_all) +@pytest.mark.parametrize('M, N', shapes_common) +@pytest.mark.parametrize('BLOCK_SIZE_N', block_size) +def test_tl_where(M, N, BLOCK_SIZE_N, dtype, sigtype): + + in_tensor = torch.randn(2 * M, N, dtype=dtype).npu() + + triton_output = torch.zeros_like(in_tensor) + + grid = (ceil_div(2 * M * N, BLOCK_SIZE_N), ) + + tl_where_kernel[grid](in_tensor, triton_output, N=N, M=M, BLOCK_SIZE_N=BLOCK_SIZE_N, optimize_dynamic_offset=False) + + torch_output = torch_tl_where(in_tensor.clone()) + assert torch.allclose(triton_output, torch_output, rtol=1e-5, atol=1e-8) + + +def triton_tl_where(in_tensor, BLOCK_SIZE): + M = in_tensor.shape[0] // 2 + N = in_tensor.shape[1] + + triton_output = torch.zeros_like(in_tensor) + grid = (ceil_div(2 * M * N, BLOCK_SIZE), ) + + tl_where_kernel[grid](in_tensor, triton_output, N=N, M=M, BLOCK_SIZE_N=BLOCK_SIZE, optimize_dynamic_offset=True) + + +def profile_performance_test(M, N, dtype, BLOCK_SIZE): + print(f"\nDetailed performance analysis: M={M}, N={N}, dtype={dtype}, block_size={BLOCK_SIZE}") + + in_tensor = torch.randn(2 * M, N, dtype=dtype).npu() + + def wrapper_func(x): + triton_tl_where(x, BLOCK_SIZE=BLOCK_SIZE) + + # Run performance analysis + profiler_wrapper(wrapper_func, in_tensor) + + +if __name__ == "__main__": + + # Optional: Run detailed profiler test (specific configuration) + profile_performance_test(512, 512, torch.float32, BLOCK_SIZE=1024) + + print("\n" + "=" * 80) + print("Test completed!") + print(f"Detailed performance analysis results saved in: ./result_profiling_tl_where/") + print("=" * 80) diff --git a/third_party/ascend/examples/pytest_ut/test_cast_full.py b/third_party/ascend/unittest/pytest_ut/test_cast_full.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_cast_full.py rename to third_party/ascend/unittest/pytest_ut/test_cast_full.py diff --git a/third_party/ascend/examples/pytest_ut/test_cat.py b/third_party/ascend/unittest/pytest_ut/test_cat.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_cat.py rename to third_party/ascend/unittest/pytest_ut/test_cat.py diff --git a/third_party/ascend/examples/pytest_ut/test_cat_dim.py b/third_party/ascend/unittest/pytest_ut/test_cat_dim.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_cat_dim.py rename to third_party/ascend/unittest/pytest_ut/test_cat_dim.py diff --git a/third_party/ascend/examples/pytest_ut/test_cdiv.py b/third_party/ascend/unittest/pytest_ut/test_cdiv.py similarity index 100% rename from 
third_party/ascend/examples/pytest_ut/test_cdiv.py rename to third_party/ascend/unittest/pytest_ut/test_cdiv.py diff --git a/third_party/ascend/examples/pytest_ut/test_ceil.py b/third_party/ascend/unittest/pytest_ut/test_ceil.py similarity index 96% rename from third_party/ascend/examples/pytest_ut/test_ceil.py rename to third_party/ascend/unittest/pytest_ut/test_ceil.py index 2ce0ebf38..faecd4fa8 100644 --- a/third_party/ascend/examples/pytest_ut/test_ceil.py +++ b/third_party/ascend/unittest/pytest_ut/test_ceil.py @@ -54,11 +54,11 @@ def triton_ceil(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexp ['float16', (2, 4096, 8), 32, 2048, 64], # ['bfloat16', (2, 4096, 8), 32, 2048, 64], ['float32', (2, 4096, 8), 32, 2048, 64], - # ['int8', (2, 4096, 8), 32, 2048, 64], + ['int8', (2, 4096, 8), 32, 2048, 64], # ['int16', (2, 4096, 8), 32, 2048, 64], # ['int32', (2, 4096, 8), 32, 2048, 64], # ['int64', (2, 4096, 8), 32, 2048, 64], - # ['uint8', (2, 4096, 8), 32, 2048, 64], + ['uint8', (2, 4096, 8), 32, 2048, 64], # ['uint16', (2, 4096, 8), 32, 2048, 64], # ['uint32', (2, 4096, 8), 32, 2048, 64], # ['uint64', (2, 4096, 8), 32, 2048, 64], diff --git a/third_party/ascend/examples/pytest_ut/test_clamp.py b/third_party/ascend/unittest/pytest_ut/test_clamp.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_clamp.py rename to third_party/ascend/unittest/pytest_ut/test_clamp.py diff --git a/third_party/ascend/examples/pytest_ut/test_common.py b/third_party/ascend/unittest/pytest_ut/test_common.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_common.py rename to third_party/ascend/unittest/pytest_ut/test_common.py diff --git a/third_party/ascend/examples/pytest_ut/test_compile_hint.py b/third_party/ascend/unittest/pytest_ut/test_compile_hint.py similarity index 87% rename from third_party/ascend/examples/pytest_ut/test_compile_hint.py rename to third_party/ascend/unittest/pytest_ut/test_compile_hint.py index 98fe5d33a..87b7fb346 100644 --- a/third_party/ascend/examples/pytest_ut/test_compile_hint.py +++ b/third_party/ascend/unittest/pytest_ut/test_compile_hint.py @@ -20,6 +20,7 @@ import triton import triton.language as tl +import triton.language.extra.cann.extension as extension import pytest import test_common @@ -35,12 +36,12 @@ def triton_compile_hint(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_ xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) - tl.compile_hint(tmp0, "hint_a") - tl.multibuffer(tmp0, 2) + extension.compile_hint(tmp0, "hint_a") + extension.multibuffer(tmp0, 2) tmp2 = tmp0 - tl.compile_hint(tmp2, "hint_b", 42) - tl.compile_hint(tmp2, "hint_c", True) - tl.compile_hint(tmp2, "hint_d", [XBLOCK, XBLOCK_SUB]) + extension.compile_hint(tmp2, "hint_b", 42) + extension.compile_hint(tmp2, "hint_c", True) + extension.compile_hint(tmp2, "hint_d", [XBLOCK, XBLOCK_SUB]) tl.store(out_ptr0 + (xindex), tmp2, xmask) diff --git a/third_party/ascend/unittest/pytest_ut/test_copy.py b/third_party/ascend/unittest/pytest_ut/test_copy.py new file mode 100644 index 000000000..f0a2a778b --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_copy.py @@ -0,0 +1,91 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import os +import pytest +import triton +import triton.language as tl +import triton.extension.buffer.language as bl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir, buffer_ir +from triton._C.libtriton.ascend import ir as ascend_ir + +os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + + +class Options: + num_warps = 4 + num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + arch = "Ascend910_95" + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + buffer_ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {}, {}) + return str(module) + + +@triton.jit +def copy( + A_ptr, + A1_ptr, + M: tl.constexpr, + N: tl.constexpr, +): + offs_a = tl.arange(0, M)[:, None] + offs_b = tl.arange(0, N)[None, :] + + offs_c = (offs_a) * M + (offs_b) + a_ptr = A_ptr + offs_c + a_val = tl.load(a_ptr) + a1_ptr = A1_ptr + offs_c + a1_val = tl.load(a1_ptr) + + add = tl.add(a_val, a1_val) + + add_ub = bl.to_buffer(add, al.ascend_address_space.UB) + A_l1 = bl.alloc(tl.float32, [M, N], al.ascend_address_space.L1) + al.copy_from_ub_to_l1(add_ub, A_l1) + + +# ============== Main for manual testing ============== + +if __name__ == "__main__": + print("=" * 60) + print("Test 1: copy ") + print("=" * 60) + mlir = compile_kernel( + copy, + {"A_ptr": "*fp32", "A1_ptr": "*fp32"}, + {"M": 16, "N": 16}, + ) + print(f"✅ Generated MLIR ({len(mlir)} chars):\n") + print(mlir) diff --git a/third_party/ascend/examples/pytest_ut/test_copysign.py b/third_party/ascend/unittest/pytest_ut/test_copysign.py similarity index 95% rename from third_party/ascend/examples/pytest_ut/test_copysign.py rename to third_party/ascend/unittest/pytest_ut/test_copysign.py index 21f63c186..40aa235f2 100644 --- a/third_party/ascend/examples/pytest_ut/test_copysign.py +++ b/third_party/ascend/unittest/pytest_ut/test_copysign.py @@ -1,108 +1,108 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
-# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import pytest -import triton -import torch -import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice -import test_common - - -@triton.jit -def copysign_kernel(x_ptr, y_ptr, z_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - mask = offsets < n_elements - - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - - z = libdevice.copysign(x, y) - - tl.store(z_ptr + offsets, z, mask=mask) - - -@pytest.mark.parametrize('shape', [ - (12, 16), -]) -@pytest.mark.parametrize('dtype', ['float32']) -def test_copysign(shape, dtype): - n_elements = shape[0] * shape[1] - x = test_common.generate_tensor(shape, dtype).npu() - y = test_common.generate_tensor(shape, dtype).npu() - - # Ensure to include some boundary cases - x[0, 0] = 3.14 - y[0, 0] = 1.0 # The result should be 3.14 - - x[0, 1] = 3.14 - y[0, 1] = -1.0 # The result should be -3.14 - - x[0, 2] = -3.14 - y[0, 2] = 1.0 # The result should be 3.14 - - x[0, 3] = -3.14 - y[0, 3] = -1.0 # The result should be -3.14 - - x[0, 4] = 0.0 - y[0, 4] = -1.0 # The result should be -0.0 - - x[0, 5] = 0.0 - y[0, 5] = 1.0 # The result should be 0.0 - - x[0, 6] = 3.14 - y[0, 6] = 0.0 # The result should be 3.14 - - x[0, 7] = 3.14 - y[0, 7] = -0.0 # The result should be -3.14 - - x[0, 8] = -3.14 - y[0, 8] = 0.0 # The result should be 3.14 - - x[0, 9] = -3.14 - y[0, 9] = -0.0 # The result should be -3.14 - - x[0, 10] = 0.0 - y[0, 10] = 0.0 # The result should be 0.0 - - x[0, 11] = 0.0 - y[0, 11] = -0.0 # The result should be -0.0 - - x[0, 12] = -0.0 - y[0, 12] = 0.0 # The result should be 0.0 - - x[0, 13] = -0.0 - y[0, 13] = -0.0 # The result should be -0.0 - - z = torch.empty_like(x) - - BLOCK_SIZE = 192 - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) - - copysign_kernel[grid](x, y, z, n_elements, BLOCK_SIZE=BLOCK_SIZE) - - expected = torch.copysign(x, y) - - torch.testing.assert_close(z, expected, rtol=1e-3, atol=1e-3) - - print("✓ COPYSIGN test PASSED!") +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import pytest +import triton +import torch +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +import test_common + + +@triton.jit +def copysign_kernel(x_ptr, y_ptr, z_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + + z = libdevice.copysign(x, y) + + tl.store(z_ptr + offsets, z, mask=mask) + + +@pytest.mark.parametrize('shape', [ + (12, 16), +]) +@pytest.mark.parametrize('dtype', ['float32']) +def test_copysign(shape, dtype): + n_elements = shape[0] * shape[1] + x = test_common.generate_tensor(shape, dtype).npu() + y = test_common.generate_tensor(shape, dtype).npu() + + # Ensure to include some boundary cases + x[0, 0] = 3.14 + y[0, 0] = 1.0 # The result should be 3.14 + + x[0, 1] = 3.14 + y[0, 1] = -1.0 # The result should be -3.14 + + x[0, 2] = -3.14 + y[0, 2] = 1.0 # The result should be 3.14 + + x[0, 3] = -3.14 + y[0, 3] = -1.0 # The result should be -3.14 + + x[0, 4] = 0.0 + y[0, 4] = -1.0 # The result should be -0.0 + + x[0, 5] = 0.0 + y[0, 5] = 1.0 # The result should be 0.0 + + x[0, 6] = 3.14 + y[0, 6] = 0.0 # The result should be 3.14 + + x[0, 7] = 3.14 + y[0, 7] = -0.0 # The result should be -3.14 + + x[0, 8] = -3.14 + y[0, 8] = 0.0 # The result should be 3.14 + + x[0, 9] = -3.14 + y[0, 9] = -0.0 # The result should be -3.14 + + x[0, 10] = 0.0 + y[0, 10] = 0.0 # The result should be 0.0 + + x[0, 11] = 0.0 + y[0, 11] = -0.0 # The result should be -0.0 + + x[0, 12] = -0.0 + y[0, 12] = 0.0 # The result should be 0.0 + + x[0, 13] = -0.0 + y[0, 13] = -0.0 # The result should be -0.0 + + z = torch.empty_like(x) + + BLOCK_SIZE = 192 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + copysign_kernel[grid](x, y, z, n_elements, BLOCK_SIZE=BLOCK_SIZE) + + expected = torch.copysign(x, y) + + torch.testing.assert_close(z, expected, rtol=1e-3, atol=1e-3) + + print("✓ COPYSIGN test PASSED!") diff --git a/third_party/ascend/examples/pytest_ut/test_cos.py b/third_party/ascend/unittest/pytest_ut/test_cos.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_cos.py rename to third_party/ascend/unittest/pytest_ut/test_cos.py diff --git a/third_party/ascend/examples/pytest_ut/test_cos_2.py 
b/third_party/ascend/unittest/pytest_ut/test_cos_2.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_cos_2.py rename to third_party/ascend/unittest/pytest_ut/test_cos_2.py diff --git a/third_party/ascend/examples/pytest_ut/test_cosh.py b/third_party/ascend/unittest/pytest_ut/test_cosh.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_cosh.py rename to third_party/ascend/unittest/pytest_ut/test_cosh.py index e04c97fd5..ee35268b6 100644 --- a/third_party/ascend/examples/pytest_ut/test_cosh.py +++ b/third_party/ascend/unittest/pytest_ut/test_cosh.py @@ -22,7 +22,7 @@ import triton import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice import test_common import torch diff --git a/third_party/ascend/examples/pytest_ut/test_cumprod.py b/third_party/ascend/unittest/pytest_ut/test_cumprod.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_cumprod.py rename to third_party/ascend/unittest/pytest_ut/test_cumprod.py diff --git a/third_party/ascend/examples/pytest_ut/test_cumsum.py b/third_party/ascend/unittest/pytest_ut/test_cumsum.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_cumsum.py rename to third_party/ascend/unittest/pytest_ut/test_cumsum.py diff --git a/third_party/ascend/unittest/pytest_ut/test_custom.py b/third_party/ascend/unittest/pytest_ut/test_custom.py new file mode 100755 index 000000000..9c2d35c0b --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_custom.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +import subprocess +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir +from triton.backends.ascend.compiler import NPUOptions, ttir_to_linalg + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel function to MLIR in linalg dialect.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + try: + options = NPUOptions() + ttir = ast_to_ttir(kernel, src, context, options, {}, {}) + metadata = { + **options.__dict__, + } + linalg = ttir_to_linalg(ttir, metadata, options, named_ops=True) + return str(linalg) + except subprocess.CalledProcessError as ex: + print(ex.stdout.decode()) + print(ex.stderr.decode()) + print("failed") + return None + + +# ============== Kernel definitions ============== + + +@al.register_custom_op +class my_custom_op: + core = al.CORE.VECTOR + pipe = al.PIPE.PIPE_V + mode = al.MODE.SIMT + + def __init__(self, x, ptr1, ptr2, offset: tl.int64, other, out=None): + pass + + +@triton.jit +def my_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + result = al.custom("my_custom_op", x, x_ptr, y_ptr + i, (1, 2, 3), [4.1, 5.2], out=y) + a = 123 + result = al.custom("my_custom_op", x, x_ptr, y_ptr, (a, n), (1.2, 3.4), out=result) + tl.store(out_ptr + i, result, mask=i < n) + + +# ============== Pytest tests ============== + + +def test_custom_op(): + """Test custom op compile to linalg MLIR.""" + mlir = compile_kernel(my_kernel, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": 
"*fp32", "n": "i32"}, + {"BLOCK": 256}) + assert mlir and len(mlir) > 0 + assert "func.func @my_kernel(" in mlir + assert "hivm.hir.custom" in mlir + for line in mlir.splitlines(): + if "hivm.hir.custom" in line: + # custom op name + assert '"my_custom_op"' in line + # All tt.ptr converted to memref. + assert "tt.ptr" not in line + # Required attributes are set. + assert "hivm.pipe = #hivm.pipe" in line + assert "hivm.tcore_type = #hivm.tcore_type" in line + assert "hivm.vf_mode = #hivm.vf_mode" in line + # All offset converted to int64. + assert 'i64, ' in line + assert 'i32, ' not in line + + +# ============== Main for manual testing ============== + +if __name__ == "__main__": + mlir = compile_kernel(my_kernel, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}) + print(f"✅ Generated MLIR ({len(mlir)} chars):\n") + print(mlir) diff --git a/third_party/ascend/examples/pytest_ut/test_cyl_bessel_i0.py b/third_party/ascend/unittest/pytest_ut/test_cyl_bessel_i0.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_cyl_bessel_i0.py rename to third_party/ascend/unittest/pytest_ut/test_cyl_bessel_i0.py index a267351fb..5eba4e969 100644 --- a/third_party/ascend/examples/pytest_ut/test_cyl_bessel_i0.py +++ b/third_party/ascend/unittest/pytest_ut/test_cyl_bessel_i0.py @@ -22,7 +22,7 @@ import triton import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice import test_common import torch diff --git a/third_party/ascend/examples/pytest_ut/test_debug_barrier.py b/third_party/ascend/unittest/pytest_ut/test_debug_barrier.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_debug_barrier.py rename to third_party/ascend/unittest/pytest_ut/test_debug_barrier.py diff --git a/third_party/ascend/examples/pytest_ut/test_device_print.py b/third_party/ascend/unittest/pytest_ut/test_device_print.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_device_print.py rename to third_party/ascend/unittest/pytest_ut/test_device_print.py diff --git a/third_party/ascend/examples/pytest_ut/test_device_print_comprehensive.py b/third_party/ascend/unittest/pytest_ut/test_device_print_comprehensive.py similarity index 95% rename from third_party/ascend/examples/pytest_ut/test_device_print_comprehensive.py rename to third_party/ascend/unittest/pytest_ut/test_device_print_comprehensive.py index f2be525b0..d623fc51d 100644 --- a/third_party/ascend/examples/pytest_ut/test_device_print_comprehensive.py +++ b/third_party/ascend/unittest/pytest_ut/test_device_print_comprehensive.py @@ -1,122 +1,122 @@ -import os -import sys -import subprocess -import tempfile -import textwrap - -import pytest -import torch -import torch_npu -import triton -import triton.language as tl - -expected_prints = [ - "Offsets:", - "Mask:", - "Pointer offsets:", - "Loaded x:", - "Scalar factor:", - "Temp result (x * 2):", - "Final y (x * 2 + 1):", - "Positive mask:", - "Block ID:", - "Block start:", - "Valid elements in this block:", -] - - -def test_comprehensive_print(): - - with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: - temp_script = f.name - - f.write( - textwrap.dedent(f""" -import os -import sys -import subprocess -import tempfile -import textwrap - -import pytest -import torch -import torch_npu -import triton -import triton.language as tl - -os.environ["TRITON_DEVICE_PRINT"] = "1" - -@triton.jit -def 
comprehensive_print_kernel( - x_ptr, - y_ptr, - mask_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - block_start = pid * BLOCK_SIZE - - offsets = block_start + tl.arange(0, BLOCK_SIZE) - tl.device_print("Offsets: ", offsets) - - mask = offsets < n_elements - tl.device_print("Mask: ", mask) - - x_ptrs = x_ptr + offsets - tl.device_print("Pointer offsets: ", offsets) - - x = tl.load(x_ptrs, mask=mask, other=0.0) - tl.device_print("Loaded x: ", x) - - scalar_factor = 2.0 - tl.device_print("Scalar factor: ", scalar_factor) - - y_temp = x * scalar_factor - tl.device_print("Temp result (x * 2): ", y_temp) - - y = y_temp + 1.0 - tl.device_print("Final y (x * 2 + 1): ", y) - - positive_mask = y > 0.0 - tl.device_print("Positive mask: ", positive_mask) - - tl.device_print("Block ID: ", pid) - tl.device_print("Block start: ", block_start) - - y_ptrs = y_ptr + offsets - tl.store(y_ptrs, y, mask=mask) - mask_count = tl.sum(mask.to(tl.int32)) - tl.device_print("Valid elements in this block: ", mask_count) - - -def test_comprehensive_print(): - size = 16 - x = torch.randn(size).npu() - y = torch.zeros(size).npu() - mask = torch.ones(size, dtype=torch.bool).npu() - BLOCK_SIZE = 32 - - h = comprehensive_print_kernel[1,](x, y, mask, size, BLOCK_SIZE=BLOCK_SIZE) - - expected = x * 2.0 + 1.0 - torch.testing.assert_close(y, expected, rtol=1e-5, atol=1e-5) - - for i in range(11): - opStr = "call @triton_print_" + str(i) - assert opStr in h.asm["ttadapter"] - - print("passed!") - - -if __name__ == "__main__": - test_comprehensive_print() - """)) - - result = subprocess.run([sys.executable, temp_script], capture_output=True, text=True, env=os.environ.copy()) - - captured_output = result.stdout + "\n=== STDERR ===\n" + result.stderr - - assert "passed!" 
in captured_output - for prefix in expected_prints: - assert prefix in captured_output +import os +import sys +import subprocess +import tempfile +import textwrap + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + +expected_prints = [ + "Offsets:", + "Mask:", + "Pointer offsets:", + "Loaded x:", + "Scalar factor:", + "Temp result (x * 2):", + "Final y (x * 2 + 1):", + "Positive mask:", + "Block ID:", + "Block start:", + "Valid elements in this block:", +] + + +def test_comprehensive_print(): + + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + temp_script = f.name + + f.write( + textwrap.dedent(f""" +import os +import sys +import subprocess +import tempfile +import textwrap + +import pytest +import torch +import torch_npu +import triton +import triton.language as tl + +os.environ["TRITON_DEVICE_PRINT"] = "1" + +@triton.jit +def comprehensive_print_kernel( + x_ptr, + y_ptr, + mask_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + + offsets = block_start + tl.arange(0, BLOCK_SIZE) + tl.device_print("Offsets: ", offsets) + + mask = offsets < n_elements + tl.device_print("Mask: ", mask) + + x_ptrs = x_ptr + offsets + tl.device_print("Pointer offsets: ", offsets) + + x = tl.load(x_ptrs, mask=mask, other=0.0) + tl.device_print("Loaded x: ", x) + + scalar_factor = 2.0 + tl.device_print("Scalar factor: ", scalar_factor) + + y_temp = x * scalar_factor + tl.device_print("Temp result (x * 2): ", y_temp) + + y = y_temp + 1.0 + tl.device_print("Final y (x * 2 + 1): ", y) + + positive_mask = y > 0.0 + tl.device_print("Positive mask: ", positive_mask) + + tl.device_print("Block ID: ", pid) + tl.device_print("Block start: ", block_start) + + y_ptrs = y_ptr + offsets + tl.store(y_ptrs, y, mask=mask) + mask_count = tl.sum(mask.to(tl.int32)) + tl.device_print("Valid elements in this block: ", mask_count) + + +def test_comprehensive_print(): + size = 16 + x = torch.randn(size).npu() + y = torch.zeros(size).npu() + mask = torch.ones(size, dtype=torch.bool).npu() + BLOCK_SIZE = 32 + + h = comprehensive_print_kernel[1,](x, y, mask, size, BLOCK_SIZE=BLOCK_SIZE) + + expected = x * 2.0 + 1.0 + torch.testing.assert_close(y, expected, rtol=1e-5, atol=1e-5) + + for i in range(11): + opStr = "call @triton_print_" + str(i) + assert opStr in h.asm["ttadapter"] + + print("passed!") + + +if __name__ == "__main__": + test_comprehensive_print() + """)) + + result = subprocess.run([sys.executable, temp_script], capture_output=True, text=True, env=os.environ.copy()) + + captured_output = result.stdout + "\n=== STDERR ===\n" + result.stderr + + assert "passed!" 
in captured_output + for prefix in expected_prints: + assert prefix in captured_output diff --git a/third_party/ascend/examples/pytest_ut/test_device_print_script.py b/third_party/ascend/unittest/pytest_ut/test_device_print_script.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_device_print_script.py rename to third_party/ascend/unittest/pytest_ut/test_device_print_script.py diff --git a/third_party/ascend/examples/pytest_ut/test_discrete_mask_loadstore.py b/third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_discrete_mask_loadstore.py rename to third_party/ascend/unittest/pytest_ut/test_discrete_mask_loadstore.py diff --git a/third_party/ascend/examples/pytest_ut/test_div.py b/third_party/ascend/unittest/pytest_ut/test_div.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_div.py rename to third_party/ascend/unittest/pytest_ut/test_div.py diff --git a/third_party/ascend/unittest/pytest_ut/test_dot.py b/third_party/ascend/unittest/pytest_ut/test_dot.py new file mode 100644 index 000000000..a4837d044 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_dot.py @@ -0,0 +1,71 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import pytest +import torch +import torch_npu +import triton +import triton.language as tl +import test_common + + +def torch_dot_None(x0, x1): + res = torch.matmul(x0, x1) + return res + + +@triton.jit +def triton_dot_2_None(output_ptr, x_ptr, y_ptr, B: tl.constexpr, C: tl.constexpr, D: tl.constexpr): + bidx = tl.arange(0, B) + cidx = tl.arange(0, C) + didx = tl.arange(0, D) + + x_mask = (bidx[:, None] < B) & (cidx[None, :] < C) + y_mask = (cidx[:, None] < C) & (didx[None, :] < D) + out_mask = (bidx[:, None] < B) & (didx[None, :] < D) + Xidx = bidx[:, None] * C + cidx[None, :] + Yidx = cidx[:, None] * D + didx[None, :] + X = tl.load(x_ptr + Xidx, mask=x_mask, other=0.0) + Y = tl.load(y_ptr + Yidx, mask=y_mask, other=0.0) + ret = tl.dot(X, Y, input_precision="hf32") + oidx = bidx[:, None] * D + didx[None, :] + tl.store(output_ptr + oidx, ret, mask=out_mask) + + +testlist1 = [ + (10, 13, 35, 39), +] + +testlist2 = [(16, 32, 16)] + +typelist = [ + 'float32', +] + + +@pytest.mark.parametrize("B, C, D", testlist2) +@pytest.mark.parametrize("sigtype", typelist) +def test_dot_2(sigtype, B, C, D): + x = test_common.generate_tensor((B, C), sigtype).npu() + y = test_common.generate_tensor((C, D), sigtype).npu() + z_ref = torch_dot_None(x, y).to(torch.float32) + z = torch.zeros((B, D), dtype=torch.float32).npu() + triton_dot_2_None[1, 1, 1](z, x, y, B, C, D) + test_common.validate_cmp(sigtype, z, z_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_dot_scaled.py b/third_party/ascend/unittest/pytest_ut/test_dot_scaled.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_dot_scaled.py rename to third_party/ascend/unittest/pytest_ut/test_dot_scaled.py diff --git a/third_party/ascend/examples/pytest_ut/test_downgrade.py b/third_party/ascend/unittest/pytest_ut/test_downgrade.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_downgrade.py rename to third_party/ascend/unittest/pytest_ut/test_downgrade.py diff --git a/third_party/ascend/examples/pytest_ut/test_elementwise_ceil.py b/third_party/ascend/unittest/pytest_ut/test_elementwise_ceil.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_elementwise_ceil.py rename to third_party/ascend/unittest/pytest_ut/test_elementwise_ceil.py diff --git a/third_party/ascend/examples/pytest_ut/test_elementwise_clip.py b/third_party/ascend/unittest/pytest_ut/test_elementwise_clip.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_elementwise_clip.py rename to third_party/ascend/unittest/pytest_ut/test_elementwise_clip.py diff --git a/third_party/ascend/examples/pytest_ut/test_elementwise_f2i.py b/third_party/ascend/unittest/pytest_ut/test_elementwise_f2i.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_elementwise_f2i.py rename to third_party/ascend/unittest/pytest_ut/test_elementwise_f2i.py diff --git a/third_party/ascend/examples/pytest_ut/test_elementwise_floor.py b/third_party/ascend/unittest/pytest_ut/test_elementwise_floor.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_elementwise_floor.py rename to third_party/ascend/unittest/pytest_ut/test_elementwise_floor.py diff --git a/third_party/ascend/examples/pytest_ut/test_elementwise_i2f.py b/third_party/ascend/unittest/pytest_ut/test_elementwise_i2f.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_elementwise_i2f.py rename to third_party/ascend/unittest/pytest_ut/test_elementwise_i2f.py diff --git 
a/third_party/ascend/examples/pytest_ut/test_elementwise_round.py b/third_party/ascend/unittest/pytest_ut/test_elementwise_round.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_elementwise_round.py rename to third_party/ascend/unittest/pytest_ut/test_elementwise_round.py diff --git a/third_party/ascend/examples/pytest_ut/test_eq.py b/third_party/ascend/unittest/pytest_ut/test_eq.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_eq.py rename to third_party/ascend/unittest/pytest_ut/test_eq.py diff --git a/third_party/ascend/examples/pytest_ut/test_eq_2.py b/third_party/ascend/unittest/pytest_ut/test_eq_2.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_eq_2.py rename to third_party/ascend/unittest/pytest_ut/test_eq_2.py diff --git a/third_party/ascend/examples/pytest_ut/test_erfinv.py b/third_party/ascend/unittest/pytest_ut/test_erfinv.py similarity index 98% rename from third_party/ascend/examples/pytest_ut/test_erfinv.py rename to third_party/ascend/unittest/pytest_ut/test_erfinv.py index 956b7e46e..9e45a5c55 100644 --- a/third_party/ascend/examples/pytest_ut/test_erfinv.py +++ b/third_party/ascend/unittest/pytest_ut/test_erfinv.py @@ -24,7 +24,7 @@ import torch_npu import pytest import test_common -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice @triton.jit diff --git a/third_party/ascend/examples/pytest_ut/test_exp.py b/third_party/ascend/unittest/pytest_ut/test_exp.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_exp.py rename to third_party/ascend/unittest/pytest_ut/test_exp.py diff --git a/third_party/ascend/examples/pytest_ut/test_exp2.py b/third_party/ascend/unittest/pytest_ut/test_exp2.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_exp2.py rename to third_party/ascend/unittest/pytest_ut/test_exp2.py diff --git a/third_party/ascend/examples/pytest_ut/test_exp_.py b/third_party/ascend/unittest/pytest_ut/test_exp_.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_exp_.py rename to third_party/ascend/unittest/pytest_ut/test_exp_.py diff --git a/third_party/ascend/examples/pytest_ut/test_expand_dims.py b/third_party/ascend/unittest/pytest_ut/test_expand_dims.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_expand_dims.py rename to third_party/ascend/unittest/pytest_ut/test_expand_dims.py diff --git a/third_party/ascend/examples/pytest_ut/test_expm1.py b/third_party/ascend/unittest/pytest_ut/test_expm1.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_expm1.py rename to third_party/ascend/unittest/pytest_ut/test_expm1.py index 78d562988..90665b030 100644 --- a/third_party/ascend/examples/pytest_ut/test_expm1.py +++ b/third_party/ascend/unittest/pytest_ut/test_expm1.py @@ -22,7 +22,7 @@ import triton import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice import test_common import torch diff --git a/third_party/ascend/examples/pytest_ut/test_extract_slice.py b/third_party/ascend/unittest/pytest_ut/test_extract_slice.py similarity index 94% rename from third_party/ascend/examples/pytest_ut/test_extract_slice.py rename to third_party/ascend/unittest/pytest_ut/test_extract_slice.py index 58906a31c..7014e73e1 100644 --- a/third_party/ascend/examples/pytest_ut/test_extract_slice.py +++ 
b/third_party/ascend/unittest/pytest_ut/test_extract_slice.py @@ -23,6 +23,7 @@ import triton import triton.language as tl +import triton.language.extra.cann.extension as extension import pytest @@ -36,7 +37,7 @@ def triton_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) output = x + y - out_sub = tl.extract_slice(output, [block_start], [32], [1]) + out_sub = extension.extract_slice(output, [block_start], [32], [1]) out_idx = block_start + tl.arange(0, 32) out_msk = out_idx < n_elements tl.store(output_ptr + out_idx, out_sub, mask=out_msk) diff --git a/third_party/ascend/examples/pytest_ut/test_fdiv.py b/third_party/ascend/unittest/pytest_ut/test_fdiv.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_fdiv.py rename to third_party/ascend/unittest/pytest_ut/test_fdiv.py diff --git a/third_party/ascend/unittest/pytest_ut/test_fixpipe.py b/third_party/ascend/unittest/pytest_ut/test_fixpipe.py new file mode 100644 index 000000000..22b9aa98f --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_fixpipe.py @@ -0,0 +1,84 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import os +import pytest +import triton +import triton.language as tl +import triton.extension.buffer.language as bl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + +os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + + +class Options: + num_warps = 4 + num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + arch = "Ascend910_95" + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {}, {}) + return str(module) + + +@triton.jit +def fixpipe( + A_ptr, + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, +): + + row_matmul = tl.program_id(0) + + offs_i = tl.arange(0, tl.constexpr(M))[:, None] # [M,1] (row axis) + offs_k = tl.arange(0, K) # [K] + + a_ptrs = A_ptr + (row_matmul + offs_i) * K + offs_k[None, :] + a_vals = tl.load(a_ptrs) # [M, K] + + ub = bl.alloc(tl.float32, [M, N], al.ascend_address_space.UB) + al.fixpipe(a_vals, ub, dual_dst_mode=al.FixpipeDualDstMode.NO_DUAL) + + +@pytest.mark.parametrize("M, K, N", [(16, 16, 16)]) +def test_fixpipe(M, K, N): + mlir = compile_kernel( + fixpipe, + { + "A_ptr": "*fp32", + }, + {"M": M, "K": K, "N": N}, + ) + assert len(mlir) > 0 diff --git a/third_party/ascend/examples/pytest_ut/test_flip.py b/third_party/ascend/unittest/pytest_ut/test_flip.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_flip.py rename to third_party/ascend/unittest/pytest_ut/test_flip.py diff --git a/third_party/ascend/examples/pytest_ut/test_floor.py b/third_party/ascend/unittest/pytest_ut/test_floor.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_floor.py rename to third_party/ascend/unittest/pytest_ut/test_floor.py diff --git a/third_party/ascend/examples/pytest_ut/test_floordiv.py b/third_party/ascend/unittest/pytest_ut/test_floordiv.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_floordiv.py rename to third_party/ascend/unittest/pytest_ut/test_floordiv.py diff --git a/third_party/ascend/unittest/pytest_ut/test_for_ptr.py b/third_party/ascend/unittest/pytest_ut/test_for_ptr.py new file mode 100644 index 000000000..47b8d28ba --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_for_ptr.py @@ -0,0 +1,116 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +types_all = [ + (torch.float32, 'float32'), +] + +shapes_common = [(128, 256), (127, 256), (127, 16), (129, 256), (77, 1024), (69, 512)] + +block_size = [128, 256, 1024] + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def profiler_wrapper(fn, *args): + result_path = "./result_profiling_for" + skip_first = 10 + wait = 0 + warmup = 3 + active = 30 + repeat = 1 + stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) + with torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], + 
schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, + skip_first=skip_first), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), record_shapes=True, + profile_memory=False, with_stack=False, with_flops=False, with_modules=False, + experimental_config=experimental_config) as prof: + stream.synchronize() + for _ in range(skip_first + (wait + warmup + active) * repeat): + fn(*args) + prof.step() + stream.synchronize() + + +@triton.jit +def for_ptr_kernel( + in_ptr, + output_ptr, + N: tl.constexpr, + M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + offset = 2 * pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + x_in = in_ptr + offset + x_out = output_ptr + offset + for _ in range(0, 2): + mask = (offset < M * N) + data = tl.load(x_in, mask=mask, other=0) + tl.store(x_out, data, mask) + x_in += BLOCK_SIZE_N + x_out += BLOCK_SIZE_N + + +@pytest.mark.parametrize('dtype, sigtype', types_all) +@pytest.mark.parametrize('M, N', shapes_common) +@pytest.mark.parametrize('BLOCK_SIZE_N', block_size) +def test_for_ptr(M, N, BLOCK_SIZE_N, dtype, sigtype): + + in_tensor = torch.randn(M, N, dtype=dtype).npu() + + triton_output = torch.zeros_like(in_tensor) + + grid = (ceil_div(2 * M * N, BLOCK_SIZE_N), ) + + for_ptr_kernel[grid](in_tensor, triton_output, N=N, M=M, BLOCK_SIZE_N=BLOCK_SIZE_N, optimize_dynamic_offset=False) + + assert torch.allclose(triton_output, in_tensor, rtol=1e-5, atol=1e-8) + + +def triton_lfor_ptr(in_tensor, BLOCK_SIZE): + M = in_tensor.shape[0] + N = in_tensor.shape[1] + + triton_output = torch.zeros_like(in_tensor) + grid = (ceil_div(2 * M * N, BLOCK_SIZE), ) + + for_ptr_kernel[grid](in_tensor, triton_output, N=N, M=M, BLOCK_SIZE_N=BLOCK_SIZE, optimize_dynamic_offset=True) + + +def profile_performance_test(M, N, dtype, BLOCK_SIZE): + print(f"\nDetailed performance analysis: M={M}, N={N}, dtype={dtype}, block_size={BLOCK_SIZE}") + + in_tensor = torch.randn(2 * M, N, dtype=dtype).npu() + + def wrapper_func(x): + triton_lfor_ptr(x, BLOCK_SIZE=BLOCK_SIZE) + + # Run performance analysis + profiler_wrapper(wrapper_func, in_tensor) + + +if __name__ == "__main__": + print("For Kernel Performance Test Suite") + print("Function: Broadcast first element") + + # Optional: Run detailed profiler test (specific configuration) + profile_performance_test(512, 512, torch.float32, BLOCK_SIZE=1024) + + print("\n" + "=" * 80) + print("Test completed!") + print(f"Detailed performance analysis results saved in: ./result_profiling_for/") + print("=" * 80) diff --git a/third_party/ascend/examples/pytest_ut/test_full.py b/third_party/ascend/unittest/pytest_ut/test_full.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_full.py rename to third_party/ascend/unittest/pytest_ut/test_full.py diff --git a/third_party/ascend/examples/pytest_ut/test_fusedattention.py b/third_party/ascend/unittest/pytest_ut/test_fusedattention.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_fusedattention.py rename to third_party/ascend/unittest/pytest_ut/test_fusedattention.py index 94bd1af04..f884dd3db 100644 --- a/third_party/ascend/examples/pytest_ut/test_fusedattention.py +++ b/third_party/ascend/unittest/pytest_ut/test_fusedattention.py @@ -37,6 +37,7 @@ import torch_npu import triton import triton.language as tl +import triton.language.extra.cann.extension as extension DEVICE = "npu" @@ -127,13 +128,13 @@ def _attn_fwd_inner(acc_ptr, l_i, m_i, q, # Accumulator, local l, 
local m, quer # Calculate start/end rows for current slice offset = i * (BLOCK_M // 4) # Extract slice data - acc_i = tl.extract_slice(acc, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) - alpha_i = tl.extract_slice(alpha, [offset], [BLOCK_M // 4], [1]) - pv_i = tl.extract_slice(pv, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + acc_i = extension.extract_slice(acc, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + alpha_i = extension.extract_slice(alpha, [offset], [BLOCK_M // 4], [1]) + pv_i = extension.extract_slice(pv, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) # Incrementally update slice: acc = acc * alpha + pv acc_i = acc_i * alpha_i[:, None] + pv_i # Write updated slice back to accumulator - acc = tl.insert_slice(acc, acc_i, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) + acc = extension.insert_slice(acc, acc_i, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1)) # 3. updated accumulator tl.store(acc_ptr + block2d_acc, acc) diff --git a/third_party/ascend/examples/pytest_ut/test_gamma.py b/third_party/ascend/unittest/pytest_ut/test_gamma.py similarity index 98% rename from third_party/ascend/examples/pytest_ut/test_gamma.py rename to third_party/ascend/unittest/pytest_ut/test_gamma.py index c03df1895..388ed36af 100644 --- a/third_party/ascend/examples/pytest_ut/test_gamma.py +++ b/third_party/ascend/unittest/pytest_ut/test_gamma.py @@ -24,7 +24,7 @@ import torch_npu import pytest import test_common -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice import numpy as np from scipy.special import gamma diff --git a/third_party/ascend/examples/pytest_ut/test_gather.py b/third_party/ascend/unittest/pytest_ut/test_gather.py similarity index 96% rename from third_party/ascend/examples/pytest_ut/test_gather.py rename to third_party/ascend/unittest/pytest_ut/test_gather.py index 279163215..08946f894 100644 --- a/third_party/ascend/examples/pytest_ut/test_gather.py +++ b/third_party/ascend/unittest/pytest_ut/test_gather.py @@ -22,6 +22,7 @@ import torch_npu import triton import triton.language as tl +import triton.language.extra.cann.extension as extension import numpy as np import test_common import pytest @@ -47,7 +48,7 @@ def gather_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.co idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1) idx = tl.load(idx_ptr + idx_offs) - out = tl.gather(src, idx, axis) + out = extension.gather(src, idx, axis) out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1) tl.store(out_ptr + out_offs, out) diff --git a/third_party/ascend/examples/model_cases/qwen.py b/third_party/ascend/unittest/pytest_ut/test_gather_simd.py similarity index 54% rename from third_party/ascend/examples/model_cases/qwen.py rename to third_party/ascend/unittest/pytest_ut/test_gather_simd.py index e58f4c3a7..6225e4ef2 100644 --- a/third_party/ascend/examples/model_cases/qwen.py +++ b/third_party/ascend/unittest/pytest_ut/test_gather_simd.py @@ -17,42 +17,30 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. - -import logging -import os +""" +Unit test for gather_2d_simd kernel. 
+""" import torch import torch_npu -import torch_npu._inductor - -from transformers import AutoTokenizer, AutoModelForCausalLM - -os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" - -logging.basicConfig(level=logging.DEBUG) - -torch.npu.config.allow_internal_format = False -torch.manual_seed(0) -torch.npu.manual_seed(0) -tokenizer = AutoTokenizer.from_pretrained("./Qwen2.5-0.5B-Instruct") - -inputs = tokenizer("Hello, how to make China great again?", return_tensors="pt").to("npu:0") -model_ = AutoModelForCausalLM.from_pretrained("./Qwen2.5-0.5B-Instruct", device_map="npu:0") -model_.eval() - - -def model(**model_inputs): - with torch.no_grad(): - return model_(**model_inputs).logits +import triton +import triton.language as tl +import pytest +from triton.language.extra.kernels import gather_2d_simd -y = model(**inputs) -logging.info("result eager: " + str(torch.flatten(y)[:100])) -model_compiled = torch.compile(model_) +@pytest.mark.parametrize("M,N,K", [ + (32, 128, 64), +]) +def test_gather_2d_simd(M, N, K): + """Test gather_2d_simd with various tensor sizes.""" + src = torch.randn(M, N, dtype=torch.float32, device='npu') + indices = torch.randint(0, N, (M, K), dtype=torch.int32, device='npu') + output = torch.empty((M, K), dtype=src.dtype, device='npu') -z = model_compiled(**inputs) -logging.info("result compiled: " + str(torch.flatten(z)[:100])) + grid = (triton.cdiv(M, 32), ) + gather_2d_simd[grid](src, indices, output, M, N, K, XBLOCK=32, XBLOCK_SUB=4) -torch.testing.assert_close(y, z, atol=1e-4, rtol=1e-4) -logging.info("qwen accuracy check pass!") + ref = torch.gather(src, 1, indices.long()) + assert torch.allclose(output, ref, rtol=1e-5, atol=1e-5) diff --git a/third_party/ascend/examples/pytest_ut/test_ge.py b/third_party/ascend/unittest/pytest_ut/test_ge.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_ge.py rename to third_party/ascend/unittest/pytest_ut/test_ge.py diff --git a/third_party/ascend/examples/pytest_ut/test_ge_2.py b/third_party/ascend/unittest/pytest_ut/test_ge_2.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_ge_2.py rename to third_party/ascend/unittest/pytest_ut/test_ge_2.py diff --git a/third_party/ascend/examples/pytest_ut/test_gelu.py b/third_party/ascend/unittest/pytest_ut/test_gelu.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_gelu.py rename to third_party/ascend/unittest/pytest_ut/test_gelu.py diff --git a/third_party/ascend/examples/pytest_ut/test_gt.py b/third_party/ascend/unittest/pytest_ut/test_gt.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_gt.py rename to third_party/ascend/unittest/pytest_ut/test_gt.py diff --git a/third_party/ascend/examples/pytest_ut/test_hd_permute.py b/third_party/ascend/unittest/pytest_ut/test_hd_permute.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_hd_permute.py rename to third_party/ascend/unittest/pytest_ut/test_hd_permute.py diff --git a/third_party/ascend/examples/pytest_ut/test_histogram.py b/third_party/ascend/unittest/pytest_ut/test_histogram.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_histogram.py rename to third_party/ascend/unittest/pytest_ut/test_histogram.py diff --git a/third_party/ascend/examples/pytest_ut/test_hoistbroadcast.py b/third_party/ascend/unittest/pytest_ut/test_hoistbroadcast.py similarity index 79% rename from third_party/ascend/examples/pytest_ut/test_hoistbroadcast.py rename to 
third_party/ascend/unittest/pytest_ut/test_hoistbroadcast.py
index 581f19b0b..79ea5b3ca 100644
--- a/third_party/ascend/examples/pytest_ut/test_hoistbroadcast.py
+++ b/third_party/ascend/unittest/pytest_ut/test_hoistbroadcast.py
@@ -166,3 +166,55 @@ def test_hoistbroadcast_compare(param_list):
     copy_all_layer_kv_cache[(len(data_ptrs), )](data_ptrs, data_strides, tgt_loc, src_loc, len(tgt_loc), 1)
     copy_all_layer_kv_cache2[(len(data_ptrs_ref), )](data_ptrs_ref, data_strides_ref, tgt_loc, src_loc, len(tgt_loc), 1)
     test_common.validate_cmp(dtype, kv_buffer, kv_buffer_ref)
+
+
+def torch_pointwise(x0):
+    if x0.dtype != torch.uint32:
+        return torch.abs(x0)
+    else:
+        return torch.abs(x0.to(torch.float32))
+
+
+@triton.jit
+def fn_npu_(output_ptr, x_ptr, y_ptr, z_ptr, XB: tl.constexpr, YB: tl.constexpr, ZB: tl.constexpr, XNUMEL: tl.constexpr,
+            YNUMEL: tl.constexpr, ZNUMEL: tl.constexpr):
+    xoffs = tl.program_id(0) * XB
+    yoffs = tl.program_id(1) * YB
+    zoffs = tl.program_id(2) * ZB
+
+    xidx = tl.arange(0, XB) + xoffs
+    yidx = tl.arange(0, YB) + yoffs
+    zidx = tl.arange(0, ZB) + zoffs
+
+    X = tl.load(x_ptr + xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :])
+    ret = tl.abs(X)
+    tl.store(output_ptr + xidx[:, None, None] * YNUMEL * ZNUMEL + yidx[None, :, None] * ZNUMEL + zidx[None, None, :],
+             ret)
+
+
+@pytest.mark.parametrize('shape', [(8, 16, 16)])
+@pytest.mark.parametrize('dtype', ['float32'])
+def test_case2(dtype, shape):
+    # Generate test data
+    x = test_common.generate_tensor(shape, dtype).npu()
+    y = test_common.generate_tensor(shape, dtype).npu()
+    z = test_common.generate_tensor(shape, dtype).npu()
+    new_shape = shape
+
+    output = torch.randint(1, new_shape, dtype=eval('torch.' + dtype)).npu()
+    output1 = output
+
+    ans = torch_pointwise(x)
+
+    XB = shape[0]
+    xnumel = shape[0]
+    YB = shape[1]
+    ynumel = shape[1]
+    ZB = shape[2]
+    znumel = shape[2]
+
+    grid = (1, 1, 1)
+
+    fn_npu_[grid](output, x, y, z, XB, YB, ZB, xnumel, ynumel, znumel)
+
+    test_common.validate_cmp(dtype, ans, output)
diff --git a/third_party/ascend/examples/pytest_ut/test_hypot.py b/third_party/ascend/unittest/pytest_ut/test_hypot.py
similarity index 97%
rename from third_party/ascend/examples/pytest_ut/test_hypot.py
rename to third_party/ascend/unittest/pytest_ut/test_hypot.py
index e79c0ff7c..9bb1e6b31 100644
--- a/third_party/ascend/examples/pytest_ut/test_hypot.py
+++ b/third_party/ascend/unittest/pytest_ut/test_hypot.py
@@ -22,7 +22,7 @@
 import triton
 import triton.language as tl
-import triton.language.extra.ascend.libdevice as libdevice
+import triton.language.extra.cann.libdevice as libdevice
 import test_common
 import torch
diff --git a/third_party/ascend/examples/pytest_ut/test_if_tensor.py b/third_party/ascend/unittest/pytest_ut/test_if_tensor.py
similarity index 100%
rename from third_party/ascend/examples/pytest_ut/test_if_tensor.py
rename to third_party/ascend/unittest/pytest_ut/test_if_tensor.py
diff --git a/third_party/ascend/examples/pytest_ut/test_index_select.py b/third_party/ascend/unittest/pytest_ut/test_index_select.py
similarity index 85%
rename from third_party/ascend/examples/pytest_ut/test_index_select.py
rename to third_party/ascend/unittest/pytest_ut/test_index_select.py
index 23a1c534b..a6d4737e4 100644
--- a/third_party/ascend/examples/pytest_ut/test_index_select.py
+++ b/third_party/ascend/unittest/pytest_ut/test_index_select.py
@@ -20,7 +20,8 @@
 import triton
 import triton.language as tl
-import triton.language.extra.ascend.libdevice as libdevice
+import
triton.language.extra.cann.extension as extension + import torch import torch_npu import pytest @@ -45,7 +46,7 @@ def torch_index_select(x0, dim, indices): def index_select_manual_kernel(in_ptr, indices_ptr, out_ptr, dim, g_stride: tl.constexpr, indice_length: tl.constexpr, g_block: tl.constexpr, g_block_sub: tl.constexpr, other_block: tl.constexpr): """ - Manual implementation using tl.get_element and tl.insert_slice. + Manual implementation using extension.get_element and extension.insert_slice. This is a baseline implementation without using the index_select_simd intrinsic. """ @@ -62,20 +63,21 @@ def index_select_manual_kernel(in_ptr, indices_ptr, out_ptr, dim, g_stride: tl.c # Manual gather: iterate over each index for i in range(0, g_block_sub): - gather_offset = tl.get_element(indices, (i, )) * g_stride + gather_offset = extension.get_element(indices, (i, )) * g_stride val = tl.load(in_ptr + gather_offset + other_idx, other_mask) - tmp_buf = tl.insert_slice(tmp_buf, val[None, :], offsets=(i, 0), sizes=(1, other_block), strides=(1, 1)) + tmp_buf = extension.insert_slice(tmp_buf, val[None, :], offsets=(i, 0), sizes=(1, other_block), + strides=(1, 1)) tl.store(out_ptr + g_idx[:, None] * g_stride + other_idx[None, :], tmp_buf, g_mask[:, None] & other_mask[None, :]) @triton.jit -def index_select_libdevice_kernel(in_ptr, indices_ptr, out_ptr, dim: tl.constexpr, other_numel: tl.constexpr, +def index_select_extension_kernel(in_ptr, indices_ptr, out_ptr, dim: tl.constexpr, other_numel: tl.constexpr, g_stride: tl.constexpr, indice_length: tl.constexpr, g_block: tl.constexpr, g_block_sub: tl.constexpr, other_block: tl.constexpr): """ - Implementation using libdevice.index_select_simd intrinsic. + Implementation using extension.index_select_simd intrinsic. This uses the hardware-accelerated SIMD index_select operation. """ @@ -89,8 +91,8 @@ def index_select_libdevice_kernel(in_ptr, indices_ptr, out_ptr, dim: tl.constexp other_idx = tl.arange(0, other_block) + other_offset other_mask = other_idx < g_stride - # Use libdevice index_select_simd - tmp_buf = libdevice.index_select_simd(src=in_ptr, dim=dim, index=indices, src_shape=(other_numel, g_stride), + # Use extension index_select_simd + tmp_buf = extension.index_select_simd(src=in_ptr, dim=dim, index=indices, src_shape=(other_numel, g_stride), src_offset=(-1, 0), read_shape=(-1, other_block)) tl.store(out_ptr + g_idx[:, None] * g_stride + other_idx[None, :], tmp_buf, @@ -118,7 +120,7 @@ def index_select_auto_kernel(in_ptr, indices_ptr, out_ptr, dim: tl.constexpr, ot # Auto-lowering: compute offsets and use standard load src_offsets = indices[:, None] * g_stride + other_idx[None, :] - tmp_buf = tl.load(in_ptr + src_offsets) + tmp_buf = tl.load(in_ptr + src_offsets, g_mask[:, None] & other_mask[None, :]) tl.store(out_ptr + g_idx[:, None] * g_stride + other_idx[None, :], tmp_buf, g_mask[:, None] & other_mask[None, :]) @@ -129,7 +131,7 @@ def index_select_auto_kernel(in_ptr, indices_ptr, out_ptr, dim: tl.constexpr, ot # ============================================================================ -def triton_index_select(x0, dim, indices, impl='libdevice', num_vec_core=48): +def triton_index_select(x0, dim, indices, impl='extension', num_vec_core=48): """ Triton implementation of index_select. 
@@ -137,7 +139,7 @@ def triton_index_select(x0, dim, indices, impl='libdevice', num_vec_core=48): x0: Source tensor dim: Dimension to select from indices: Indices to select - impl: Implementation to use ('manual', 'libdevice', or 'auto') + impl: Implementation to use ('manual', 'extension', or 'auto') num_vec_core: Number of vector cores to use Returns: @@ -153,13 +155,13 @@ def triton_index_select(x0, dim, indices, impl='libdevice', num_vec_core=48): # Calculate UB space allocation enable_multi_buffer = True - available_ub_space = (125 * 1024) // (x0.element_size() * (2 if enable_multi_buffer else 1)) - - if g_stride * 2 < available_ub_space: - other_block = g_stride - g_block_sub = available_ub_space // other_block - else: - other_block = available_ub_space + ub_size = 125 * 1024 // (2 if enable_multi_buffer else 1) + other_block = g_stride + g_block_sub = ub_size // ( + # max memory consumption: arith.select + other (mask handling in auto) + x0.element_size() * g_stride * 3 + indices.element_size()) + if g_block_sub < 1: + other_block = (ub_size - indices.element_size()) // x0.element_size() g_block_sub = 1 # Select kernel based on implementation @@ -167,8 +169,8 @@ def triton_index_select(x0, dim, indices, impl='libdevice', num_vec_core=48): kernel = index_select_manual_kernel kernel[num_vec_core, 1, 1](x0, indices, out, dim, g_stride=g_stride, indice_length=indice_length, g_block=g_block, g_block_sub=g_block_sub, other_block=other_block, multibuffer=False) - elif impl == 'libdevice': - kernel = index_select_libdevice_kernel + elif impl == 'extension': + kernel = index_select_extension_kernel kernel[num_vec_core, 1, 1](x0, indices, out, dim, other_numel=sz[0], g_stride=g_stride, indice_length=indice_length, g_block=g_block, g_block_sub=g_block_sub, other_block=other_block) @@ -222,7 +224,7 @@ def triton_index_select(x0, dim, indices, impl='libdevice', num_vec_core=48): @pytest.mark.parametrize("src_shape, dim, indice_shape, dtype", INDEX_SELECT_TEST_CASES) def test_index_select_manual(src_shape, dim, indice_shape, dtype): - """Test manual implementation using tl.get_element and tl.insert_slice.""" + """Test manual implementation using extension.get_element and extension.insert_slice.""" x0 = test_common.generate_tensor(shape=src_shape, dtype=dtype).npu() indices = torch.randint(0, src_shape[dim], size=indice_shape, dtype=torch.int32).npu() @@ -234,13 +236,13 @@ def test_index_select_manual(src_shape, dim, indice_shape, dtype): @pytest.mark.parametrize("src_shape, dim, indice_shape, dtype", INDEX_SELECT_TEST_CASES) -def test_index_select_libdevice(src_shape, dim, indice_shape, dtype): - """Test libdevice.index_select_simd implementation.""" +def test_index_select_extension(src_shape, dim, indice_shape, dtype): + """Test extension.index_select_simd implementation.""" x0 = test_common.generate_tensor(shape=src_shape, dtype=dtype).npu() indices = torch.randint(0, src_shape[dim], size=indice_shape, dtype=torch.int32).npu() torch_ref = torch_index_select(x0, dim, indices) - triton_cal = triton_index_select(x0, dim, indices, impl='libdevice', num_vec_core=48) + triton_cal = triton_index_select(x0, dim, indices, impl='extension', num_vec_core=48) test_common.validate_cmp(dtype, triton_cal, torch_ref) @@ -259,8 +261,8 @@ def test_index_select_auto(src_shape, dim, indice_shape, dtype): # Quick smoke test if __name__ == "__main__": - test_index_select_libdevice((500000, 37), 0, (324344, ), "float32") - print("libdevice implementation passed") + test_index_select_extension((500000, 37), 0, 
(324344, ), "float32") + print("extension implementation passed") test_index_select_auto((500000, 37), 0, (324344, ), "float32") print("auto-lowering implementation passed") diff --git a/third_party/ascend/examples/pytest_ut/test_index_select_inductor.py b/third_party/ascend/unittest/pytest_ut/test_index_select_inductor.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_index_select_inductor.py rename to third_party/ascend/unittest/pytest_ut/test_index_select_inductor.py diff --git a/third_party/ascend/examples/pytest_ut/test_insert_slice.py b/third_party/ascend/unittest/pytest_ut/test_insert_slice.py similarity index 88% rename from third_party/ascend/examples/pytest_ut/test_insert_slice.py rename to third_party/ascend/unittest/pytest_ut/test_insert_slice.py index 8cf450fb3..9f968971d 100644 --- a/third_party/ascend/examples/pytest_ut/test_insert_slice.py +++ b/third_party/ascend/unittest/pytest_ut/test_insert_slice.py @@ -23,7 +23,7 @@ import triton import triton.language as tl - +import triton.language.extra.cann.extension as extension import pytest @@ -36,11 +36,11 @@ def triton_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) - x_sub = tl.extract_slice(x, [block_start + SLICE_OFFSET], [SLICE_SIZE], [1]) - y_sub = tl.extract_slice(y, [block_start + SLICE_OFFSET], [SLICE_SIZE], [1]) + x_sub = extension.extract_slice(x, [block_start + SLICE_OFFSET], [SLICE_SIZE], [1]) + y_sub = extension.extract_slice(y, [block_start + SLICE_OFFSET], [SLICE_SIZE], [1]) output_sub = x_sub + y_sub output = tl.load(output_ptr + offsets, mask=mask) - output = tl.insert_slice(output, output_sub, [block_start + SLICE_OFFSET], [SLICE_SIZE], [1]) + output = extension.insert_slice(output, output_sub, [block_start + SLICE_OFFSET], [SLICE_SIZE], [1]) tl.store(output_ptr + offsets, output, mask=mask) diff --git a/third_party/ascend/examples/pytest_ut/test_interleave.py b/third_party/ascend/unittest/pytest_ut/test_interleave.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_interleave.py rename to third_party/ascend/unittest/pytest_ut/test_interleave.py diff --git a/third_party/ascend/examples/pytest_ut/test_invert.py b/third_party/ascend/unittest/pytest_ut/test_invert.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_invert.py rename to third_party/ascend/unittest/pytest_ut/test_invert.py diff --git a/third_party/ascend/examples/pytest_ut/test_isfinited.py b/third_party/ascend/unittest/pytest_ut/test_isfinited.py similarity index 95% rename from third_party/ascend/examples/pytest_ut/test_isfinited.py rename to third_party/ascend/unittest/pytest_ut/test_isfinited.py index 22bab3741..aa87ecb84 100644 --- a/third_party/ascend/examples/pytest_ut/test_isfinited.py +++ b/third_party/ascend/unittest/pytest_ut/test_isfinited.py @@ -20,6 +20,7 @@ import triton import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice import torch import torch_npu import pytest @@ -53,7 +54,7 @@ def torch_func(x0): def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): idx = tl.arange(0, N) x0 = tl.load(in_ptr0 + idx) - ret = tl.math.isfinited(x0) + ret = libdevice.isfinited(x0) tl.store(out_ptr0 + idx, ret) def triton_func(x0, N): @@ -82,7 +83,7 @@ def torch_func(x0): def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): idx = tl.arange(0, N) x0 = tl.load(in_ptr0 + idx) - ret = tl.math.finitef(x0) 
+ ret = libdevice.finitef(x0) tl.store(out_ptr0 + idx, ret) def triton_func(x0, N): @@ -116,7 +117,7 @@ def torch_func(x0): def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): idx = tl.arange(0, N) x0 = tl.load(in_ptr0 + idx) - ret = tl.math.isfinited(x0) + ret = libdevice.isfinited(x0) tl.store(out_ptr0 + idx, ret) def triton_func(x0, N): diff --git a/third_party/ascend/examples/pytest_ut/test_isnan.py b/third_party/ascend/unittest/pytest_ut/test_isnan.py similarity index 88% rename from third_party/ascend/examples/pytest_ut/test_isnan.py rename to third_party/ascend/unittest/pytest_ut/test_isnan.py index 4304b670d..dedf3d903 100644 --- a/third_party/ascend/examples/pytest_ut/test_isnan.py +++ b/third_party/ascend/unittest/pytest_ut/test_isnan.py @@ -52,7 +52,7 @@ def torch_func(x0): def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): idx = tl.arange(0, N) x0 = tl.load(in_ptr0 + idx) - ret = tl.extra.ascend.libdevice.isnan(x0) + ret = tl.extra.cann.libdevice.isnan(x0) tl.store(out_ptr0 + idx, ret) def triton_func(x0, N): @@ -76,7 +76,6 @@ def triton_func(x0, N): @pytest.mark.parametrize("sigtype", invalid_types) @pytest.mark.parametrize("N", shapes) -@test_common.raises_with_match(triton.compiler.errors.CompilationError, "input arg type does not match.") def test_isnan_invalid_dtype(sigtype, N): def torch_func(x0): @@ -87,7 +86,7 @@ def torch_func(x0): def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): idx = tl.arange(0, N) x0 = tl.load(in_ptr0 + idx) - ret = tl.extra.ascend.libdevice.isnan(x0) + ret = tl.extra.cann.libdevice.isnan(x0) tl.store(out_ptr0 + idx, ret) def triton_func(x0, N): @@ -97,8 +96,12 @@ def triton_func(x0, N): x0 = test_common.generate_tensor(shape=(N, ), dtype=sigtype).npu() x0[1] = float('nan') - - torch_ref = torch_func(x0) - triton_cal = triton_func(x0, N) - test_common.validate_cmp("bool", triton_cal, torch_ref) - assert triton_cal[1] == True + flag = False + try: + torch_ref = torch_func(x0) + triton_cal = triton_func(x0, N) + test_common.validate_cmp("bool", triton_cal, torch_ref) + assert triton_cal[1] == True + except Exception as e: + flag = True + assert flag diff --git a/third_party/ascend/examples/pytest_ut/test_join.py b/third_party/ascend/unittest/pytest_ut/test_join.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_join.py rename to third_party/ascend/unittest/pytest_ut/test_join.py diff --git a/third_party/ascend/examples/pytest_ut/test_lanzcos.py b/third_party/ascend/unittest/pytest_ut/test_lanzcos.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_lanzcos.py rename to third_party/ascend/unittest/pytest_ut/test_lanzcos.py diff --git a/third_party/ascend/examples/pytest_ut/test_launcher_empty_signature.py b/third_party/ascend/unittest/pytest_ut/test_launcher_empty_signature.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_launcher_empty_signature.py rename to third_party/ascend/unittest/pytest_ut/test_launcher_empty_signature.py diff --git a/third_party/ascend/examples/pytest_ut/test_layernorm.py b/third_party/ascend/unittest/pytest_ut/test_layernorm.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_layernorm.py rename to third_party/ascend/unittest/pytest_ut/test_layernorm.py diff --git a/third_party/ascend/examples/pytest_ut/test_ldst.py b/third_party/ascend/unittest/pytest_ut/test_ldst.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_ldst.py rename to 
third_party/ascend/unittest/pytest_ut/test_ldst.py diff --git a/third_party/ascend/examples/pytest_ut/test_le.py b/third_party/ascend/unittest/pytest_ut/test_le.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_le.py rename to third_party/ascend/unittest/pytest_ut/test_le.py diff --git a/third_party/ascend/examples/pytest_ut/test_lgamma.py b/third_party/ascend/unittest/pytest_ut/test_lgamma.py similarity index 98% rename from third_party/ascend/examples/pytest_ut/test_lgamma.py rename to third_party/ascend/unittest/pytest_ut/test_lgamma.py index b9bdb91f4..bc5db118a 100644 --- a/third_party/ascend/examples/pytest_ut/test_lgamma.py +++ b/third_party/ascend/unittest/pytest_ut/test_lgamma.py @@ -24,7 +24,7 @@ import torch_npu import pytest import test_common -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice @triton.jit diff --git a/third_party/ascend/examples/pytest_ut/test_linearize.py b/third_party/ascend/unittest/pytest_ut/test_linearize.py similarity index 96% rename from third_party/ascend/examples/pytest_ut/test_linearize.py rename to third_party/ascend/unittest/pytest_ut/test_linearize.py index 542268734..4a33ef367 100644 --- a/third_party/ascend/examples/pytest_ut/test_linearize.py +++ b/third_party/ascend/unittest/pytest_ut/test_linearize.py @@ -112,8 +112,7 @@ def triton_foo(a, d, shape, dtype): print(f"XBLOCK={XBLOCK},YBLOCK={YBLOCK}, block_dim={((x + XBLOCK -1 )//XBLOCK) * (((y*z) + YBLOCK -1 ) // YBLOCK)}") grid = ((x + XBLOCK - 1) // XBLOCK, ((y * z) + YBLOCK - 1) // YBLOCK, 1) - triton_gpu_revised[grid](a, d, out, y * z, x, SHAPE0=z, SHAPE1=y, SHAPE2=x, YBLOCK=YBLOCK, XBLOCK=XBLOCK, - enable_linearize=True) + triton_gpu_revised[grid](a, d, out, y * z, x, SHAPE0=z, SHAPE1=y, SHAPE2=x, YBLOCK=YBLOCK, XBLOCK=XBLOCK) return out @@ -224,8 +223,7 @@ def test_linearize_offset_handling(dtype, sigtype): # Run triton kernel linearize_offset_kernel[(16, )](bias_ptr=bias_ptr, output_ptr=output_ptr, experts_ids_ptr=experts_ids_ptr, N=N, EM=EM, stride_bias_e=stride_bias_e, stride_bias_n=stride_bias_n, - BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, GROUP_SIZE_M=GROUP_SIZE_M, - enable_linearize=True) + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, GROUP_SIZE_M=GROUP_SIZE_M) # Compute reference result expected = torch_linearize_offset(bias_ptr, experts_ids_ptr, N, EM, stride_bias_e, stride_bias_n, BLOCK_SIZE_M, @@ -280,7 +278,7 @@ def test_expand_dims_and_add(batch_size, buffer_len, dtype, sigtype): cache_ref = torch.zeros(batch_size, 1, cache_len, dtype=dtype).npu() numel = batch_size * buffer_len * 2 torch_expand_dims_and_add(buffer, cache_ref, buffer_len, block, numel) - expand_dims_and_add[block, cache_len](buffer, cache, buffer_len, block, numel, enable_linearize=True) + expand_dims_and_add[block, cache_len](buffer, cache, buffer_len, block, numel) test_common.validate_cmp(sigtype, cache, cache_ref) @@ -530,7 +528,7 @@ def test_linearize_jump_load(batch_size, buffer_len, dtype, sigtype): cache2 = cache2_ref.npu() torch_save_cache_to_buffer(buffer_ref, cache1_ref, cache2_ref, buffer_len, cache_len, block) - save_cache_to_buffer[(batch_size, 1, 1)](buffer, cache1, cache2, buffer_len, block, enable_linearize=True) + save_cache_to_buffer[(batch_size, 1, 1)](buffer, cache1, cache2, buffer_len, block) test_common.validate_cmp(sigtype, buffer, buffer_ref) @@ -547,8 +545,7 @@ def test_linearize_jump_load_with_offset(batch_size, buffer_len, dtype, sigtype) cache2 = cache2_ref.npu() 
torch_save_cache_to_buffer_with_offset(buffer_ref, cache1_ref, cache2_ref, buffer_len, cache_len, block) - save_cache_to_buffer_with_offset[(batch_size, 1, 1)](buffer, cache1, cache2, buffer_len, block, - enable_linearize=True) + save_cache_to_buffer_with_offset[(batch_size, 1, 1)](buffer, cache1, cache2, buffer_len, block) test_common.validate_cmp(sigtype, buffer, buffer_ref) @@ -566,12 +563,10 @@ def test_linearize_rearrange(batch_size, buffer_len, dtype, sigtype): cache = cache_ref.npu() torch_rearrange_and_combine_two_buffer(buffer1_ref, buffer2_ref, cache_ref, buffer_len, num_block, block) - rearrange_and_combine_two_buffer[(batch_size, 1, 1)](buffer1, buffer2, cache, buffer_len, num_block, block, - enable_linearize=True) + rearrange_and_combine_two_buffer[(batch_size, 1, 1)](buffer1, buffer2, cache, buffer_len, num_block, block) test_common.validate_cmp(sigtype, cache, cache_ref) -@pytest.mark.skip(reason="mask load still has issues to be fixed by bisheng") @pytest.mark.parametrize('dtype,sigtype', types) @pytest.mark.parametrize('batch_size,buffer_len', cache_shapes) def test_linearize_jump_load_with_mask(batch_size, buffer_len, dtype, sigtype): @@ -588,12 +583,10 @@ def test_linearize_jump_load_with_mask(batch_size, buffer_len, dtype, sigtype): mask_num = 16 torch_save_cache_to_buffer_with_mask(buffer_ref, cache1_ref, cache2_ref, mask_ref, buffer_len, cache_len, block, mask_num) - save_cache_to_buffer_with_mask[(batch_size, 1, 1)](buffer, cache1, cache2, mask, buffer_len, block, mask_num, - enable_linearize=True) + save_cache_to_buffer_with_mask[(batch_size, 1, 1)](buffer, cache1, cache2, mask, buffer_len, block, mask_num) test_common.validate_cmp(sigtype, buffer, buffer_ref) -@pytest.mark.skip(reason="mask still has issues to be fixed") @pytest.mark.parametrize('dtype,sigtype', types) @pytest.mark.parametrize('batch_size,buffer_len', cache_shapes) def test_linearize_rearrange_with_mask(batch_size, buffer_len, dtype, sigtype): @@ -605,6 +598,5 @@ def test_linearize_rearrange_with_mask(batch_size, buffer_len, dtype, sigtype): cache2 = cache2_ref.npu() torch_rearrange_cache_with_mask(cache1_ref, cache2_ref, 2, buffer_len, num_block, block) - rearrange_cache_with_mask[(batch_size, 1, 1)](cache1, cache2, 2, buffer_len, num_block, block, - enable_linearize=True) + rearrange_cache_with_mask[(batch_size, 1, 1)](cache1, cache2, 2, buffer_len, num_block, block) test_common.validate_cmp(sigtype, cache2, cache2_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_linearize_fallback.py b/third_party/ascend/unittest/pytest_ut/test_linearize_fallback.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_linearize_fallback.py rename to third_party/ascend/unittest/pytest_ut/test_linearize_fallback.py diff --git a/third_party/ascend/unittest/pytest_ut/test_linearize_mask.py b/third_party/ascend/unittest/pytest_ut/test_linearize_mask.py new file mode 100644 index 000000000..64790e193 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_linearize_mask.py @@ -0,0 +1,128 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +types_all = [ + (torch.float32, 'float32'), +] + +shapes_common = [(128, 256), (127, 256), (127, 16), (129, 256), (77, 1024), (69, 512)] + +block_size = [128, 256, 1024] + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def profiler_wrapper(fn, *args): + result_path = "./result_profiling_broadcast" + skip_first = 10 + wait = 0 + warmup = 3 + active = 30 + repeat = 1 
+ stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) + with torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], + schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, + skip_first=skip_first), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), record_shapes=True, + profile_memory=False, with_stack=False, with_flops=False, with_modules=False, + experimental_config=experimental_config) as prof: + stream.synchronize() + for _ in range(skip_first + (wait + warmup + active) * repeat): + fn(*args) + prof.step() + stream.synchronize() + + +@triton.jit +def linearize_mask_broadcast_kernel( + in_ptr, + output_ptr, + N: tl.constexpr, + M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + offset = tl.multiple_of(pid * BLOCK_SIZE_N, N) + x1 = (offset + tl.arange(0, BLOCK_SIZE_N)) // N + mask1 = (x1 < M) + data = tl.load(in_ptr + x1 * N, mask=mask1, other=0) + x2 = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + tl.store(output_ptr + x2, data) + + +def torch_linearize_mask_broadcast(in_tensor): + M = in_tensor.shape[0] // 2 + N = in_tensor.shape[1] + + output = torch.zeros_like(in_tensor) + + first_elements = in_tensor[:M, 0:1] + output[:M] = first_elements.expand(-1, N) + + return output + + +@pytest.mark.parametrize('dtype, sigtype', types_all) +@pytest.mark.parametrize('M, N', shapes_common) +@pytest.mark.parametrize('BLOCK_SIZE_N', block_size) +def test_linearize_mask_broadcast(M, N, BLOCK_SIZE_N, dtype, sigtype): + + in_tensor = torch.randn(2 * M, N, dtype=dtype).npu() + + triton_output = torch.zeros_like(in_tensor) + + grid = (ceil_div(2 * M * N, BLOCK_SIZE_N), ) + + linearize_mask_broadcast_kernel[grid](in_tensor, triton_output, N=N, M=M, BLOCK_SIZE_N=BLOCK_SIZE_N, + optimize_dynamic_offset=True) + + torch_output = torch_linearize_mask_broadcast(in_tensor.clone()) + assert torch.allclose(triton_output, torch_output, rtol=1e-5, atol=1e-8) + + +def triton_linearize_mask_broadcast(in_tensor, BLOCK_SIZE): + M = in_tensor.shape[0] // 2 + N = in_tensor.shape[1] + + triton_output = torch.zeros_like(in_tensor) + grid = (ceil_div(2 * M * N, BLOCK_SIZE), ) + + linearize_mask_broadcast_kernel[grid](in_tensor, triton_output, N=N, M=M, BLOCK_SIZE_N=BLOCK_SIZE, + optimize_dynamic_offset=True) + + +def profile_performance_test(M, N, dtype, BLOCK_SIZE): + print(f"\nDetailed performance analysis: M={M}, N={N}, dtype={dtype}, block_size={BLOCK_SIZE}") + + in_tensor = torch.randn(2 * M, N, dtype=dtype).npu() + + def wrapper_func(x): + triton_linearize_mask_broadcast(x, BLOCK_SIZE=BLOCK_SIZE) + + # Run performance analysis + profiler_wrapper(wrapper_func, in_tensor) + + +if __name__ == "__main__": + print("Broadcast Kernel Performance Test Suite") + print("Function: Broadcast first element of first M rows, set remaining M rows to zero") + + # Optional: Run detailed profiler test (specific configuration) + profile_performance_test(512, 512, torch.float32, BLOCK_SIZE=1024) + + print("\n" + "=" * 80) + print("Test completed!") + print(f"Detailed performance analysis results saved in: ./result_profiling_broadcast/") + print("=" * 80) diff --git a/third_party/ascend/unittest/pytest_ut/test_linearize_mask_fallback.py 
b/third_party/ascend/unittest/pytest_ut/test_linearize_mask_fallback.py new file mode 100644 index 000000000..17b6a0987 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_linearize_mask_fallback.py @@ -0,0 +1,129 @@ +import pytest + +import triton +import triton.language as tl +import test_common + +import torch +import torch_npu + +types_all = [ + (torch.float32, 'float32'), +] + +shapes_common = [(128, 256), (127, 256), (127, 16), (129, 256), (77, 1024), (69, 512)] + +block_size = [128, 256, 1024] + + +def ceil_div(a, b): + return (a + b - 1) // b + + +def profiler_wrapper(fn, *args): + result_path = "./result_profiling_mask_fallback" + skip_first = 10 + wait = 0 + warmup = 3 + active = 30 + repeat = 1 + stream = torch.npu.current_stream() + experimental_config = torch_npu.profiler._ExperimentalConfig( + aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization, + profiler_level=torch_npu.profiler.ProfilerLevel.Level1, l2_cache=False, data_simplification=False) + with torch_npu.profiler.profile( + activities=[torch_npu.profiler.ProfilerActivity.CPU, torch_npu.profiler.ProfilerActivity.NPU], + schedule=torch_npu.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat, + skip_first=skip_first), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(result_path), record_shapes=True, + profile_memory=False, with_stack=False, with_flops=False, with_modules=False, + experimental_config=experimental_config) as prof: + stream.synchronize() + for _ in range(skip_first + (wait + warmup + active) * repeat): + fn(*args) + prof.step() + stream.synchronize() + + +@triton.jit +def linearize_mask_broadcast_kernel( + in_ptr, + output_ptr, + N: tl.constexpr, + M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + offset = tl.multiple_of(pid * BLOCK_SIZE_N, N) + x1 = (offset + tl.arange(0, BLOCK_SIZE_N)) + x2 = x1 // N + mask1 = (x1 < M * N) + data = tl.load(in_ptr + x2 * N, mask=mask1, other=0) + x3 = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + tl.store(output_ptr + x3, data) + + +def torch_linearize_mask_broadcast(in_tensor): + M = in_tensor.shape[0] // 2 + N = in_tensor.shape[1] + + output = torch.zeros_like(in_tensor) + + first_elements = in_tensor[:M, 0:1] + output[:M] = first_elements.expand(-1, N) + + return output + + +@pytest.mark.parametrize('dtype, sigtype', types_all) +@pytest.mark.parametrize('M, N', shapes_common) +@pytest.mark.parametrize('BLOCK_SIZE_N', block_size) +def test_linearize_mask_broadcast(M, N, BLOCK_SIZE_N, dtype, sigtype): + + in_tensor = torch.randn(2 * M, N, dtype=dtype).npu() + + triton_output = torch.zeros_like(in_tensor) + + grid = (ceil_div(2 * M * N, BLOCK_SIZE_N), ) + + linearize_mask_broadcast_kernel[grid](in_tensor, triton_output, N=N, M=M, BLOCK_SIZE_N=BLOCK_SIZE_N, + optimize_dynamic_offset=True, enable_mask_fallback_conversion=True) + + torch_output = torch_linearize_mask_broadcast(in_tensor.clone()) + assert torch.allclose(triton_output, torch_output, rtol=1e-5, atol=1e-8) + + +def triton_linearize_mask_broadcast(in_tensor, BLOCK_SIZE): + M = in_tensor.shape[0] // 2 + N = in_tensor.shape[1] + + triton_output = torch.zeros_like(in_tensor) + grid = (ceil_div(2 * M * N, BLOCK_SIZE), ) + + linearize_mask_broadcast_kernel[grid](in_tensor, triton_output, N=N, M=M, BLOCK_SIZE_N=BLOCK_SIZE, + optimize_dynamic_offset=True, enable_mask_fallback_conversion=False) + + +def profile_performance_test(M, N, dtype, BLOCK_SIZE): + print(f"\nDetailed performance analysis: M={M}, N={N}, dtype={dtype}, 
block_size={BLOCK_SIZE}") + + in_tensor = torch.randn(2 * M, N, dtype=dtype).npu() + + def wrapper_func(x): + triton_linearize_mask_broadcast(x, BLOCK_SIZE=BLOCK_SIZE) + + # Run performance analysis + profiler_wrapper(wrapper_func, in_tensor) + + +if __name__ == "__main__": + print("mask fallback Kernel Performance Test Suite") + print("Function: Broadcast first element of first M rows, set remaining M rows to zero") + + # Optional: Run detailed profiler test (specific configuration) + profile_performance_test(512, 512, torch.float32, BLOCK_SIZE=1024) + + print("\n" + "=" * 80) + print("Test completed!") + print(f"Detailed performance analysis results saved in: ./result_profiling_mask_fallback/") + print("=" * 80) diff --git a/third_party/ascend/examples/pytest_ut/test_linearize_permute.py b/third_party/ascend/unittest/pytest_ut/test_linearize_permute.py similarity index 94% rename from third_party/ascend/examples/pytest_ut/test_linearize_permute.py rename to third_party/ascend/unittest/pytest_ut/test_linearize_permute.py index e4b64609c..8aaa295f5 100644 --- a/third_party/ascend/examples/pytest_ut/test_linearize_permute.py +++ b/third_party/ascend/unittest/pytest_ut/test_linearize_permute.py @@ -318,8 +318,7 @@ def test_triton_gpu_kernel(Z, Y, X, dtype, sigtype): pytest.skip(f"ynumel:{ynumel} not divisible by YBLOCK:{YBLOCK}") grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) - triton_gpu[grid](a, b, out, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, SHAPE0=Z, SHAPE1=Y, SHAPE2=X, - enable_linearize=True) + triton_gpu[grid](a, b, out, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, SHAPE0=Z, SHAPE1=Y, SHAPE2=X) test_common.validate_cmp(sigtype, out_ref, out) @@ -332,7 +331,7 @@ def test_k_load_perm_select(xnumel, ynumel, XBLOCK, YBLOCK, dtype, sigtype): out_ptr = torch.zeros_like(in_ptr) grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) - k_load_perm_select[grid](in_ptr, out_ptr, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, enable_linearize=True) + k_load_perm_select[grid](in_ptr, out_ptr, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK) out_ref = torch.zeros_like(out_ptr) y_idx = torch.arange(ynumel).unsqueeze(1).npu() # [ynumel, 1] @@ -349,7 +348,14 @@ def test_k_store_perm_select(xnumel, ynumel, XBLOCK, YBLOCK, dtype, sigtype): out_ptr = torch.zeros_like(in_ptr) grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) - k_store_perm_select[grid](in_ptr, out_ptr, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, enable_linearize=True) + k_store_perm_select[grid]( + in_ptr, + out_ptr, + ynumel, + xnumel, + YBLOCK=YBLOCK, + XBLOCK=XBLOCK, + ) out_ref = torch.zeros_like(out_ptr) y_idx = torch.arange(ynumel).unsqueeze(1).npu() @@ -372,8 +378,7 @@ def test_k_load_moddiv_noperm(Z, Y, X, dtype, sigtype): YBLOCK = 64 grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) - k_load_moddiv_noperm[grid](in_flat, out, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, SHAPE0=Z, SHAPE1=Y, SHAPE2=X, - enable_linearize=True) + k_load_moddiv_noperm[grid](in_flat, out, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, SHAPE0=Z, SHAPE1=Y, SHAPE2=X) torch.testing.assert_close(out, in_flat) @@ -392,7 +397,7 @@ def test_k_store_moddiv_noperm(Z, Y, X, dtype, sigtype): grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) k_store_moddiv_noperm[grid](in_flat, out, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, SHAPE0=Z, SHAPE1=Y, - SHAPE2=X, enable_linearize=True) + SHAPE2=X) torch.testing.assert_close(out, in_flat) @@ -411,8 +416,7 @@ def test_k_load_moddiv_perm(Z, Y, X, dtype, 
sigtype): grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) - k_load_moddiv_perm[grid](in_flat, out, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, SHAPE0=Z, SHAPE1=Y, SHAPE2=X, - enable_linearize=True) + k_load_moddiv_perm[grid](in_flat, out, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, SHAPE0=Z, SHAPE1=Y, SHAPE2=X) torch.testing.assert_close(out, in_flat) @@ -431,8 +435,7 @@ def test_k_store_moddiv_perm(Z, Y, X, dtype, sigtype): grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) - k_store_moddiv_perm[grid](in_flat, out, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, SHAPE0=Z, SHAPE1=Y, SHAPE2=X, - enable_linearize=True) + k_store_moddiv_perm[grid](in_flat, out, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, SHAPE0=Z, SHAPE1=Y, SHAPE2=X) torch.testing.assert_close(out, in_flat) @@ -451,7 +454,7 @@ def test_k_load_store_moddiv_noperm(Z, Y, X, dtype, sigtype): grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) k_load_store_moddiv_noperm[grid](in_flat, out, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, SHAPE0=Z, SHAPE1=Y, - SHAPE2=X, enable_linearize=True) + SHAPE2=X) ref = (a + 2).contiguous().view(-1) torch.testing.assert_close(out, ref) @@ -475,7 +478,7 @@ def test_k_load_store_moddiv_perm(Z, Y, X, dtype, sigtype): grid = (ceil_div(xnumel, XBLOCK), ceil_div(ynumel, YBLOCK), 1) k_load_store_moddiv_perm[grid](a_flat, out, ynumel, xnumel, YBLOCK=YBLOCK, XBLOCK=XBLOCK, SHAPE0=Z, SHAPE1=Y, - SHAPE2=X, enable_linearize=True) + SHAPE2=X) a_reshaped = a + 1 out_ref = a_reshaped.contiguous().view(-1) @@ -499,8 +502,16 @@ def test_k_load_perm_scalar(y1_numel, y0_numel, x2_numel, Y1BLOCK, Y0BLOCK, Y0BL x2_numel=x2_numel, Y1BLOCK=Y1BLOCK, Y0BLOCK=Y0BLOCK, Y0BLOCK_SUB=Y0BLOCK_SUB, X2BLOCK_SUB=X2BLOCK_SUB) - k_load_perm_scalar[grid](in_ptr=in_ptr, out_ptr=out_ptr_triton, y1_numel=y1_numel, y0_numel=y0_numel, - x2_numel=x2_numel, Y1BLOCK=Y1BLOCK, Y0BLOCK=Y0BLOCK, Y0BLOCK_SUB=Y0BLOCK_SUB, - X2BLOCK_SUB=X2BLOCK_SUB, enable_linearize=True) + k_load_perm_scalar[grid]( + in_ptr=in_ptr, + out_ptr=out_ptr_triton, + y1_numel=y1_numel, + y0_numel=y0_numel, + x2_numel=x2_numel, + Y1BLOCK=Y1BLOCK, + Y0BLOCK=Y0BLOCK, + Y0BLOCK_SUB=Y0BLOCK_SUB, + X2BLOCK_SUB=X2BLOCK_SUB, + ) torch.testing.assert_close(out_ptr_triton, out_ptr_triton_ref, rtol=1e-5, atol=1e-6) diff --git a/third_party/ascend/examples/pytest_ut/test_load.py b/third_party/ascend/unittest/pytest_ut/test_load.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_load.py rename to third_party/ascend/unittest/pytest_ut/test_load.py diff --git a/third_party/ascend/examples/pytest_ut/test_load_store.py b/third_party/ascend/unittest/pytest_ut/test_load_store.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_load_store.py rename to third_party/ascend/unittest/pytest_ut/test_load_store.py diff --git a/third_party/ascend/examples/pytest_ut/test_log.py b/third_party/ascend/unittest/pytest_ut/test_log.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_log.py rename to third_party/ascend/unittest/pytest_ut/test_log.py diff --git a/third_party/ascend/examples/pytest_ut/test_log10.py b/third_party/ascend/unittest/pytest_ut/test_log10.py similarity index 95% rename from third_party/ascend/examples/pytest_ut/test_log10.py rename to third_party/ascend/unittest/pytest_ut/test_log10.py index 69ae32257..caa8bbcce 100644 --- a/third_party/ascend/examples/pytest_ut/test_log10.py +++ b/third_party/ascend/unittest/pytest_ut/test_log10.py @@ -1,90 +1,90 @@ -# 
Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import pytest -import triton -import torch -import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice -import test_common - - -@triton.jit -def log10_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - mask = offsets < n_elements - - x = tl.load(x_ptr + offsets, mask=mask) - - y = libdevice.log10(x) - - tl.store(y_ptr + offsets, y, mask=mask) - - -@pytest.mark.parametrize('shape', [ - (12, 16), -]) -@pytest.mark.parametrize('dtype', ['float32']) -def test_log10(shape, dtype): - x = test_common.generate_tensor(shape, dtype).npu() - - x[0, 0] = 1.0 # log10(1) = 0 - x[0, 1] = 10.0 # log10(10) = 1 - x[0, 2] = 100.0 # log10(100) = 2 - x[0, 3] = 0.1 # log10(0.1) = -1 - x[0, 4] = 0.0 # log10(0) = -inf - x[0, 5] = -1.0 # log10(-1) = NaN - x[0, 6] = 2.0 # log10(2) ≈ 0.3010 - - y = torch.empty_like(x) - - BLOCK_SIZE = 192 - grid = lambda meta: (triton.cdiv(192, meta['BLOCK_SIZE']), ) - - log10_kernel[grid](x, y, 192, BLOCK_SIZE=BLOCK_SIZE) - - expected = torch.log10(x) - print(f"triton_ret = {y}") - print(f"triton_ret = {expected}") - - valid_mask = (x > 0) - - if torch.any(valid_mask): - valid_y = y[valid_mask] - valid_expected = expected[valid_mask] - - torch.testing.assert_close(valid_y, valid_expected, rtol=1e-3, atol=1e-3) - - # Check if negative values return NaN - negative_mask = (x < 0) - if torch.any(negative_mask): - negative_y = y[negative_mask] - assert torch.all(torch.isnan(negative_y)), "Negative inputs should return NaN" - - # Check if zero value returns -inf - zero_mask = (x == 0) - if torch.any(zero_mask): - zero_y = y[zero_mask] - assert torch.all(torch.isinf(zero_y) & (zero_y < 0)), "Zero inputs should return -inf" - - print("✓ LOG10 test PASSED!") +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import pytest +import triton +import torch +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +import test_common + + +@triton.jit +def log10_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + + y = libdevice.log10(x) + + tl.store(y_ptr + offsets, y, mask=mask) + + +@pytest.mark.parametrize('shape', [ + (12, 16), +]) +@pytest.mark.parametrize('dtype', ['float32']) +def test_log10(shape, dtype): + x = test_common.generate_tensor(shape, dtype).npu() + + x[0, 0] = 1.0 # log10(1) = 0 + x[0, 1] = 10.0 # log10(10) = 1 + x[0, 2] = 100.0 # log10(100) = 2 + x[0, 3] = 0.1 # log10(0.1) = -1 + x[0, 4] = 0.0 # log10(0) = -inf + x[0, 5] = -1.0 # log10(-1) = NaN + x[0, 6] = 2.0 # log10(2) ≈ 0.3010 + + y = torch.empty_like(x) + + BLOCK_SIZE = 192 + grid = lambda meta: (triton.cdiv(192, meta['BLOCK_SIZE']), ) + + log10_kernel[grid](x, y, 192, BLOCK_SIZE=BLOCK_SIZE) + + expected = torch.log10(x) + print(f"triton_ret = {y}") + print(f"triton_ret = {expected}") + + valid_mask = (x > 0) + + if torch.any(valid_mask): + valid_y = y[valid_mask] + valid_expected = expected[valid_mask] + + torch.testing.assert_close(valid_y, valid_expected, rtol=1e-3, atol=1e-3) + + # Check if negative values return NaN + negative_mask = (x < 0) + if torch.any(negative_mask): + negative_y = y[negative_mask] + assert torch.all(torch.isnan(negative_y)), "Negative inputs should return NaN" + + # Check if zero value returns -inf + zero_mask = (x == 0) + if torch.any(zero_mask): + zero_y = y[zero_mask] + assert torch.all(torch.isinf(zero_y) & (zero_y < 0)), "Zero inputs should return -inf" + + print("✓ LOG10 test PASSED!") diff --git a/third_party/ascend/examples/pytest_ut/test_log1p.py b/third_party/ascend/unittest/pytest_ut/test_log1p.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_log1p.py rename to third_party/ascend/unittest/pytest_ut/test_log1p.py index f33a0ec27..c066ea35c 100644 --- a/third_party/ascend/examples/pytest_ut/test_log1p.py +++ b/third_party/ascend/unittest/pytest_ut/test_log1p.py @@ -23,7 +23,7 @@ import triton.language as tl import torch import test_common -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice def torch_log1p(x0, 
x1): diff --git a/third_party/ascend/examples/pytest_ut/test_log2.py b/third_party/ascend/unittest/pytest_ut/test_log2.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_log2.py rename to third_party/ascend/unittest/pytest_ut/test_log2.py diff --git a/third_party/ascend/examples/pytest_ut/test_log_2.py b/third_party/ascend/unittest/pytest_ut/test_log_2.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_log_2.py rename to third_party/ascend/unittest/pytest_ut/test_log_2.py diff --git a/third_party/ascend/examples/pytest_ut/test_logical_and.py b/third_party/ascend/unittest/pytest_ut/test_logical_and.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_logical_and.py rename to third_party/ascend/unittest/pytest_ut/test_logical_and.py diff --git a/third_party/ascend/examples/pytest_ut/test_logical_or.py b/third_party/ascend/unittest/pytest_ut/test_logical_or.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_logical_or.py rename to third_party/ascend/unittest/pytest_ut/test_logical_or.py diff --git a/third_party/ascend/examples/pytest_ut/test_lshift.py b/third_party/ascend/unittest/pytest_ut/test_lshift.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_lshift.py rename to third_party/ascend/unittest/pytest_ut/test_lshift.py diff --git a/third_party/ascend/examples/pytest_ut/test_lt.py b/third_party/ascend/unittest/pytest_ut/test_lt.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_lt.py rename to third_party/ascend/unittest/pytest_ut/test_lt.py diff --git a/third_party/ascend/examples/pytest_ut/test_makeblockptr_permute.py b/third_party/ascend/unittest/pytest_ut/test_makeblockptr_permute.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_makeblockptr_permute.py rename to third_party/ascend/unittest/pytest_ut/test_makeblockptr_permute.py diff --git a/third_party/ascend/examples/pytest_ut/test_max_constancy.py b/third_party/ascend/unittest/pytest_ut/test_max_constancy.py similarity index 93% rename from third_party/ascend/examples/pytest_ut/test_max_constancy.py rename to third_party/ascend/unittest/pytest_ut/test_max_constancy.py index 155772467..679879014 100644 --- a/third_party/ascend/examples/pytest_ut/test_max_constancy.py +++ b/third_party/ascend/unittest/pytest_ut/test_max_constancy.py @@ -1,26 +1,27 @@ -import torch -import triton -import triton.language as tl -import pytest - - -@triton.jit -def compile_hint_kernel(input_ptr, output_ptr, n_elements: tl.constexpr, SIZE: tl.constexpr): - offsets = tl.arange(0, n_elements) - val = tl.load(input_ptr + offsets) - val = tl.max_constancy(val, SIZE) - tl.store(output_ptr + offsets, val) - - -@pytest.mark.parametrize('sigtype', [ - 'int32', - #'int64', 'int16', 'int8', - #'uint8', 'uint16', 'uint32', 'uint64', - #'float32', 'float16', 'bfloat16', 'bool' -]) -def test_compile_hint(sigtype): - n_elements = 10 - dtype = eval(f"torch.{sigtype}") - x = torch.ones((n_elements, ), dtype=dtype).npu() - y = torch.zeros((n_elements, ), dtype=dtype).npu() - compile_hint_kernel[(1, )](x, y, n_elements, 1) +import torch +import triton +import triton.language as tl +import triton.language.extra.cann.extension as extension +import pytest + + +@triton.jit +def compile_hint_kernel(input_ptr, output_ptr, n_elements: tl.constexpr, SIZE: tl.constexpr): + offsets = tl.arange(0, n_elements) + val = tl.load(input_ptr + offsets) + val = tl.max_constancy(val, SIZE) + 
tl.store(output_ptr + offsets, val) + + +@pytest.mark.parametrize('sigtype', [ + 'int32', + #'int64', 'int16', 'int8', + #'uint8', 'uint16', 'uint32', 'uint64', + #'float32', 'float16', 'bfloat16', 'bool' +]) +def test_compile_hint(sigtype): + n_elements = 10 + dtype = eval(f"torch.{sigtype}") + x = torch.ones((n_elements, ), dtype=dtype).npu() + y = torch.zeros((n_elements, ), dtype=dtype).npu() + compile_hint_kernel[(1, )](x, y, n_elements, 1) diff --git a/third_party/ascend/examples/pytest_ut/test_max_contiguous.py b/third_party/ascend/unittest/pytest_ut/test_max_contiguous.py similarity index 96% rename from third_party/ascend/examples/pytest_ut/test_max_contiguous.py rename to third_party/ascend/unittest/pytest_ut/test_max_contiguous.py index 966035afd..c5c226825 100644 --- a/third_party/ascend/examples/pytest_ut/test_max_contiguous.py +++ b/third_party/ascend/unittest/pytest_ut/test_max_contiguous.py @@ -1,26 +1,26 @@ -import torch -import triton -import triton.language as tl -import pytest - - -@triton.jit -def compile_hint_kernel(input_ptr, output_ptr, n_elements: tl.constexpr, SIZE: tl.constexpr): - offsets = tl.arange(0, n_elements) - val = tl.load(input_ptr + offsets) - val = tl.max_contiguous(val, SIZE) - tl.store(output_ptr + offsets, val) - - -@pytest.mark.parametrize('sigtype', [ - 'int32', - #'int64', 'int16', 'int8', - #'uint8', 'uint16', 'uint32', 'uint64', - #'float32', 'float16', 'bfloat16', 'bool' -]) -def test_compile_hint(sigtype): - n_elements = 10 - dtype = eval(f"torch.{sigtype}") - x = torch.ones((n_elements, ), dtype=dtype).npu() - y = torch.zeros((n_elements, ), dtype=dtype).npu() - compile_hint_kernel[(1, )](x, y, n_elements, 1) +import torch +import triton +import triton.language as tl +import pytest + + +@triton.jit +def compile_hint_kernel(input_ptr, output_ptr, n_elements: tl.constexpr, SIZE: tl.constexpr): + offsets = tl.arange(0, n_elements) + val = tl.load(input_ptr + offsets) + val = tl.max_contiguous(val, SIZE) + tl.store(output_ptr + offsets, val) + + +@pytest.mark.parametrize('sigtype', [ + 'int32', + #'int64', 'int16', 'int8', + #'uint8', 'uint16', 'uint32', 'uint64', + #'float32', 'float16', 'bfloat16', 'bool' +]) +def test_compile_hint(sigtype): + n_elements = 10 + dtype = eval(f"torch.{sigtype}") + x = torch.ones((n_elements, ), dtype=dtype).npu() + y = torch.zeros((n_elements, ), dtype=dtype).npu() + compile_hint_kernel[(1, )](x, y, n_elements, 1) diff --git a/third_party/ascend/examples/pytest_ut/test_max_dim0.py b/third_party/ascend/unittest/pytest_ut/test_max_dim0.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_max_dim0.py rename to third_party/ascend/unittest/pytest_ut/test_max_dim0.py diff --git a/third_party/ascend/examples/pytest_ut/test_max_dim1.py b/third_party/ascend/unittest/pytest_ut/test_max_dim1.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_max_dim1.py rename to third_party/ascend/unittest/pytest_ut/test_max_dim1.py diff --git a/third_party/ascend/examples/pytest_ut/test_max_propagate_nan.py b/third_party/ascend/unittest/pytest_ut/test_max_propagate_nan.py similarity index 96% rename from third_party/ascend/examples/pytest_ut/test_max_propagate_nan.py rename to third_party/ascend/unittest/pytest_ut/test_max_propagate_nan.py index 0f2e60654..198a37180 100644 --- a/third_party/ascend/examples/pytest_ut/test_max_propagate_nan.py +++ b/third_party/ascend/unittest/pytest_ut/test_max_propagate_nan.py @@ -1,20 +1,20 @@ -import torch -import torch_npu -import triton -import 
triton.language as tl -import pytest - - -def test_max_propagate_nan(): - - @triton.jit - def func(in_ptr, out_ptr): - a = tl.load(in_ptr + tl.arange(0, 8)[:, None] * 8 + tl.arange(0, 8)[None, :]) - a = tl.max(a, 0, propagate_nan=True) - tl.store(out_ptr + tl.arange(0, 8), a) - - a = torch.randn((8, 8), device="npu") - std = a.max(0)[0] - ans = torch.zeros((8, ), dtype=torch.float32, device="npu") - func[1, 1, 1](a, ans) - torch.testing.assert_close(std, ans) +import torch +import torch_npu +import triton +import triton.language as tl +import pytest + + +def test_max_propagate_nan(): + + @triton.jit + def func(in_ptr, out_ptr): + a = tl.load(in_ptr + tl.arange(0, 8)[:, None] * 8 + tl.arange(0, 8)[None, :]) + a = tl.max(a, 0, propagate_nan=True) + tl.store(out_ptr + tl.arange(0, 8), a) + + a = torch.randn((8, 8), device="npu") + std = a.max(0)[0] + ans = torch.zeros((8, ), dtype=torch.float32, device="npu") + func[1, 1, 1](a, ans) + torch.testing.assert_close(std, ans) diff --git a/third_party/ascend/examples/pytest_ut/test_max_vector.py b/third_party/ascend/unittest/pytest_ut/test_max_vector.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_max_vector.py rename to third_party/ascend/unittest/pytest_ut/test_max_vector.py diff --git a/third_party/ascend/examples/pytest_ut/test_maximum.py b/third_party/ascend/unittest/pytest_ut/test_maximum.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_maximum.py rename to third_party/ascend/unittest/pytest_ut/test_maximum.py diff --git a/third_party/ascend/examples/pytest_ut/test_mean_dim0.py b/third_party/ascend/unittest/pytest_ut/test_mean_dim0.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_mean_dim0.py rename to third_party/ascend/unittest/pytest_ut/test_mean_dim0.py diff --git a/third_party/ascend/examples/pytest_ut/test_mean_dim1.py b/third_party/ascend/unittest/pytest_ut/test_mean_dim1.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_mean_dim1.py rename to third_party/ascend/unittest/pytest_ut/test_mean_dim1.py diff --git a/third_party/ascend/examples/pytest_ut/test_mean_vector.py b/third_party/ascend/unittest/pytest_ut/test_mean_vector.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_mean_vector.py rename to third_party/ascend/unittest/pytest_ut/test_mean_vector.py diff --git a/third_party/ascend/examples/pytest_ut/test_min_dim0.py b/third_party/ascend/unittest/pytest_ut/test_min_dim0.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_min_dim0.py rename to third_party/ascend/unittest/pytest_ut/test_min_dim0.py diff --git a/third_party/ascend/examples/pytest_ut/test_min_dim1.py b/third_party/ascend/unittest/pytest_ut/test_min_dim1.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_min_dim1.py rename to third_party/ascend/unittest/pytest_ut/test_min_dim1.py diff --git a/third_party/ascend/examples/pytest_ut/test_min_vector.py b/third_party/ascend/unittest/pytest_ut/test_min_vector.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_min_vector.py rename to third_party/ascend/unittest/pytest_ut/test_min_vector.py diff --git a/third_party/ascend/examples/pytest_ut/test_minimum.py b/third_party/ascend/unittest/pytest_ut/test_minimum.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_minimum.py rename to third_party/ascend/unittest/pytest_ut/test_minimum.py diff --git 
a/third_party/ascend/examples/pytest_ut/test_mod.py b/third_party/ascend/unittest/pytest_ut/test_mod.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_mod.py rename to third_party/ascend/unittest/pytest_ut/test_mod.py diff --git a/third_party/ascend/examples/pytest_ut/test_mul.py b/third_party/ascend/unittest/pytest_ut/test_mul.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_mul.py rename to third_party/ascend/unittest/pytest_ut/test_mul.py diff --git a/third_party/ascend/examples/pytest_ut/test_multi_return.py b/third_party/ascend/unittest/pytest_ut/test_multi_return.py similarity index 98% rename from third_party/ascend/examples/pytest_ut/test_multi_return.py rename to third_party/ascend/unittest/pytest_ut/test_multi_return.py index ab2497df6..1935e5225 100644 --- a/third_party/ascend/examples/pytest_ut/test_multi_return.py +++ b/third_party/ascend/unittest/pytest_ut/test_multi_return.py @@ -23,6 +23,7 @@ import torch import torch_npu import triton +from triton.language.math import tanh import triton.language as tl device = 'npu' @@ -156,7 +157,7 @@ def liger_cross_entropy_kernel( d = 0.0 # d is the sum. use the notation from the paper ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation if HAS_SOFTCAPPING: - ori_X_y = softcap * tl.math.tanh(ori_X_y / softcap) + ori_X_y = softcap * tanh(ori_X_y / softcap) # Label smoothing is a general case of normal cross entropy scaled_x_sum = 0.0 @@ -171,7 +172,7 @@ def liger_cross_entropy_kernel( # Ensure float32 precision for softmax calculation ).cast(tl.float32) if HAS_SOFTCAPPING: - X_block = softcap * tl.math.tanh(X_block / softcap) + X_block = softcap * tanh(X_block / softcap) block_max = tl.max(X_block) if label_smoothing > 0: # scale X beforehand to avoid overflow @@ -199,7 +200,7 @@ def liger_cross_entropy_kernel( # Ensure float32 precision for softmax calculation ).cast(tl.float32) if HAS_SOFTCAPPING: - intermediate = tl.math.tanh(X_block / softcap) + intermediate = tanh(X_block / softcap) X_block = softcap * intermediate if not HAS_WEIGHT: @@ -234,7 +235,7 @@ def liger_cross_entropy_kernel( X_block = dloss_ori + dloss_smooth + dz_loss # chain rule softcapping - # d(softcap * tl.math.tanh(x / softcap)) = (1 - tl.math.tanh^2(x / softcap)) + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) if HAS_SOFTCAPPING: X_block = X_block * (1 - intermediate * intermediate) diff --git a/third_party/ascend/examples/pytest_ut/test_multiple_of.py b/third_party/ascend/unittest/pytest_ut/test_multiple_of.py similarity index 96% rename from third_party/ascend/examples/pytest_ut/test_multiple_of.py rename to third_party/ascend/unittest/pytest_ut/test_multiple_of.py index 099804434..6813ee8bc 100644 --- a/third_party/ascend/examples/pytest_ut/test_multiple_of.py +++ b/third_party/ascend/unittest/pytest_ut/test_multiple_of.py @@ -1,26 +1,26 @@ -import torch -import triton -import triton.language as tl -import pytest - - -@triton.jit -def compile_hint_kernel(input_ptr, output_ptr, n_elements: tl.constexpr, SIZE: tl.constexpr): - offsets = tl.arange(0, n_elements) - val = tl.load(input_ptr + offsets) - val = tl.multiple_of(val, SIZE) - tl.store(output_ptr + offsets, val) - - -@pytest.mark.parametrize('sigtype', [ - 'int32', - #'int64', 'int16', 'int8', - #'uint8', 'uint16', 'uint32', 'uint64', - #'float32', 'float16', 'bfloat16', 'bool' -]) -def test_compile_hint(sigtype): - n_elements = 10 - dtype = eval(f"torch.{sigtype}") - x = 
torch.ones((n_elements, ), dtype=dtype).npu() - y = torch.zeros((n_elements, ), dtype=dtype).npu() - compile_hint_kernel[(1, )](x, y, n_elements, 1) +import torch +import triton +import triton.language as tl +import pytest + + +@triton.jit +def compile_hint_kernel(input_ptr, output_ptr, n_elements: tl.constexpr, SIZE: tl.constexpr): + offsets = tl.arange(0, n_elements) + val = tl.load(input_ptr + offsets) + val = tl.multiple_of(val, SIZE) + tl.store(output_ptr + offsets, val) + + +@pytest.mark.parametrize('sigtype', [ + 'int32', + #'int64', 'int16', 'int8', + #'uint8', 'uint16', 'uint32', 'uint64', + #'float32', 'float16', 'bfloat16', 'bool' +]) +def test_compile_hint(sigtype): + n_elements = 10 + dtype = eval(f"torch.{sigtype}") + x = torch.ones((n_elements, ), dtype=dtype).npu() + y = torch.zeros((n_elements, ), dtype=dtype).npu() + compile_hint_kernel[(1, )](x, y, n_elements, 1) diff --git a/third_party/ascend/examples/pytest_ut/test_ne.py b/third_party/ascend/unittest/pytest_ut/test_ne.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_ne.py rename to third_party/ascend/unittest/pytest_ut/test_ne.py diff --git a/third_party/ascend/examples/pytest_ut/test_nearbyint.py b/third_party/ascend/unittest/pytest_ut/test_nearbyint.py similarity index 95% rename from third_party/ascend/examples/pytest_ut/test_nearbyint.py rename to third_party/ascend/unittest/pytest_ut/test_nearbyint.py index ffd1cbfac..f02b896d7 100644 --- a/third_party/ascend/examples/pytest_ut/test_nearbyint.py +++ b/third_party/ascend/unittest/pytest_ut/test_nearbyint.py @@ -1,69 +1,69 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import pytest -import triton -import torch -import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice -import test_common - - -@triton.jit -def nearbyint_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - mask = offsets < n_elements - - x = tl.load(x_ptr + offsets, mask=mask) - - y = libdevice.nearbyint(x) - - tl.store(y_ptr + offsets, y, mask=mask) - - -@pytest.mark.parametrize('shape', [ - (12, 16), -]) -@pytest.mark.parametrize('dtype', ['float32']) -def test_nearbyint(shape, dtype): - n_elements = shape[0] * shape[1] - x = test_common.generate_tensor(shape, dtype).npu() - - # Ensure some boundary cases are included - x[0, 0] = 0.0 - x[0, 1] = 3.14 # Should be rounded to 3.0 - x[0, 2] = -2.71 # Should be rounded to -3.0 - x[0, 3] = 5.0 # Integer, should remain unchanged - x[0, 4] = -3.0 # Negative integer, should remain unchanged - x[0, 5] = 2.5 # Should be rounded to 3.0 - x[0, 6] = -1.5 # Should be rounded to -2.0 - - y = torch.empty_like(x) - - BLOCK_SIZE = 192 - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) - - nearbyint_kernel[grid](x, y, n_elements, BLOCK_SIZE=BLOCK_SIZE) - - expected = torch.round(x) - test_common.validate_cmp('float32', y, expected) +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import pytest +import triton +import torch +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +import test_common + + +@triton.jit +def nearbyint_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + + y = libdevice.nearbyint(x) + + tl.store(y_ptr + offsets, y, mask=mask) + + +@pytest.mark.parametrize('shape', [ + (12, 16), +]) +@pytest.mark.parametrize('dtype', ['float32']) +def test_nearbyint(shape, dtype): + n_elements = shape[0] * shape[1] + x = test_common.generate_tensor(shape, dtype).npu() + + # Ensure some boundary cases are included + x[0, 0] = 0.0 + x[0, 1] = 3.14 # Should be rounded to 3.0 + x[0, 2] = -2.71 # Should be rounded to -3.0 + x[0, 3] = 5.0 # Integer, should remain unchanged + x[0, 4] = -3.0 # Negative integer, should remain unchanged + x[0, 5] = 2.5 # Should be rounded to 3.0 + x[0, 6] = -1.5 # Should be rounded to -2.0 + + y = torch.empty_like(x) + + BLOCK_SIZE = 192 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + nearbyint_kernel[grid](x, y, n_elements, BLOCK_SIZE=BLOCK_SIZE) + + expected = torch.round(x) + test_common.validate_cmp('float32', y, expected) diff --git a/third_party/ascend/examples/pytest_ut/test_nearest.py b/third_party/ascend/unittest/pytest_ut/test_nearest.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_nearest.py rename to third_party/ascend/unittest/pytest_ut/test_nearest.py diff --git a/third_party/ascend/examples/pytest_ut/test_neg.py b/third_party/ascend/unittest/pytest_ut/test_neg.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_neg.py rename to third_party/ascend/unittest/pytest_ut/test_neg.py diff --git a/third_party/ascend/examples/pytest_ut/test_nextafter.py b/third_party/ascend/unittest/pytest_ut/test_nextafter.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_nextafter.py rename to third_party/ascend/unittest/pytest_ut/test_nextafter.py index a12ffd00c..4df371dd1 100644 --- a/third_party/ascend/examples/pytest_ut/test_nextafter.py +++ b/third_party/ascend/unittest/pytest_ut/test_nextafter.py @@ -22,7 +22,7 @@ import triton import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice import test_common import torch diff --git a/third_party/ascend/examples/pytest_ut/test_not.py b/third_party/ascend/unittest/pytest_ut/test_not.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_not.py rename to third_party/ascend/unittest/pytest_ut/test_not.py diff --git a/third_party/ascend/examples/pytest_ut/test_npu_indexing.py b/third_party/ascend/unittest/pytest_ut/test_npu_indexing.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_npu_indexing.py rename to third_party/ascend/unittest/pytest_ut/test_npu_indexing.py diff --git a/third_party/ascend/examples/pytest_ut/test_npu_indexing2.py b/third_party/ascend/unittest/pytest_ut/test_npu_indexing2.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_npu_indexing2.py rename to third_party/ascend/unittest/pytest_ut/test_npu_indexing2.py diff --git a/third_party/ascend/examples/pytest_ut/test_or.py b/third_party/ascend/unittest/pytest_ut/test_or.py similarity index 100% rename from 
third_party/ascend/examples/pytest_ut/test_or.py rename to third_party/ascend/unittest/pytest_ut/test_or.py diff --git a/third_party/ascend/unittest/pytest_ut/test_parallel.py b/third_party/ascend/unittest/pytest_ut/test_parallel.py new file mode 100644 index 000000000..ff41149de --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_parallel.py @@ -0,0 +1,99 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import pytest + +import triton +import triton.language as tl + +import torch +import torch_npu + +import test_common +import triton.language.extra.cann.extension as extension + + +@triton.jit +def triton_add(in_ptr0, in_ptr1, out_ptr0, L: tl.constexpr, M: tl.constexpr, N: tl.constexpr): + lblk_idx = tl.arange(0, L) + mblk_idx = tl.arange(0, M) + nblk_idx = tl.arange(0, N) + idx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + x0 = tl.load(in_ptr0 + idx) + x1 = tl.load(in_ptr1 + idx) + ret = x0 + x1 + + for _ in extension.parallel(2, 5, 2, bind_sub_block=False): + ret = ret + x1 + + for _ in extension.parallel(2, 10, 3, bind_sub_block=False): + ret = ret + x0 + + odx = lblk_idx[:, None, None] * N * M + mblk_idx[None, :, None] * N + nblk_idx[None, None, :] + tl.store(out_ptr0 + odx, ret) + + +testlist = [ + (3, 5, 8), +] + + +def get_torch_typename(dtype): + if dtype == 'float32': + tyname = torch.float32 + elif dtype == 'int32': + tyname = torch.int32 + elif dtype == 'int64': + tyname = torch.int64 + elif dtype == 'float16': + tyname = torch.float16 + elif dtype == 'bfloat16': + tyname = torch.bfloat16 + elif dtype == 'int16': + tyname = torch.int16 + elif dtype == 'int8': + tyname = torch.int8 + elif dtype == 'bool': + tyname = torch.bool + else: + raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype)) + + return tyname + + +typelist = ['int8', 'int16', 'int32', 'int64'] + + +@pytest.mark.parametrize('L, M, N', testlist) +@pytest.mark.parametrize('sigtype', typelist) +def test_add_bind_false(sigtype, L, M, N): + dtype = get_torch_typename(sigtype) + shape = (L, M, N) + x0 = test_common.generate_tensor(shape=(L, M, N), dtype=sigtype).npu() + x1 = test_common.generate_tensor(shape=(L, M, N), dtype=sigtype).npu() + y_ref = x0 + x1 + x1 + x1 + x0 + x0 + x0 + + output = torch.zeros(shape, dtype=dtype).npu() + h = triton_add[1, 1, 1](x0, x1, output, L, M, N) + + test_common.validate_cmp(sigtype, output, y_ref) + code_str = 
h.asm["ttadapter"] + count = code_str.count("hivm.parallel_loop") + assert count == 2 diff --git a/third_party/ascend/examples/pytest_ut/test_permute.py b/third_party/ascend/unittest/pytest_ut/test_permute.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_permute.py rename to third_party/ascend/unittest/pytest_ut/test_permute.py diff --git a/third_party/ascend/examples/pytest_ut/test_permute_full.py b/third_party/ascend/unittest/pytest_ut/test_permute_full.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_permute_full.py rename to third_party/ascend/unittest/pytest_ut/test_permute_full.py diff --git a/third_party/ascend/examples/pytest_ut/test_permute_reshape.py b/third_party/ascend/unittest/pytest_ut/test_permute_reshape.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_permute_reshape.py rename to third_party/ascend/unittest/pytest_ut/test_permute_reshape.py diff --git a/third_party/ascend/examples/pytest_ut/test_pointer_type.py b/third_party/ascend/unittest/pytest_ut/test_pointer_type.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_pointer_type.py rename to third_party/ascend/unittest/pytest_ut/test_pointer_type.py diff --git a/third_party/ascend/examples/pytest_ut/test_pow.py b/third_party/ascend/unittest/pytest_ut/test_pow.py similarity index 98% rename from third_party/ascend/examples/pytest_ut/test_pow.py rename to third_party/ascend/unittest/pytest_ut/test_pow.py index 1285a1d5b..1cf7ac23f 100644 --- a/third_party/ascend/examples/pytest_ut/test_pow.py +++ b/third_party/ascend/unittest/pytest_ut/test_pow.py @@ -20,7 +20,7 @@ import triton import triton.language as tl -from triton.language.extra.ascend.libdevice import pow +from triton.language.extra.cann.libdevice import pow import torch import torch_npu import pytest diff --git a/third_party/ascend/examples/pytest_ut/test_precise_div.py b/third_party/ascend/unittest/pytest_ut/test_precise_div.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_precise_div.py rename to third_party/ascend/unittest/pytest_ut/test_precise_div.py diff --git a/third_party/ascend/examples/pytest_ut/test_precise_sqrt.py b/third_party/ascend/unittest/pytest_ut/test_precise_sqrt.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_precise_sqrt.py rename to third_party/ascend/unittest/pytest_ut/test_precise_sqrt.py diff --git a/third_party/ascend/examples/pytest_ut/test_ptr_add.py b/third_party/ascend/unittest/pytest_ut/test_ptr_add.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_ptr_add.py rename to third_party/ascend/unittest/pytest_ut/test_ptr_add.py diff --git a/third_party/ascend/examples/pytest_ut/test_rand.py b/third_party/ascend/unittest/pytest_ut/test_rand.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_rand.py rename to third_party/ascend/unittest/pytest_ut/test_rand.py diff --git a/third_party/ascend/examples/pytest_ut/test_range.py b/third_party/ascend/unittest/pytest_ut/test_range.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_range.py rename to third_party/ascend/unittest/pytest_ut/test_range.py diff --git a/third_party/ascend/examples/pytest_ut/test_ravel.py b/third_party/ascend/unittest/pytest_ut/test_ravel.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_ravel.py rename to third_party/ascend/unittest/pytest_ut/test_ravel.py diff --git 
a/third_party/ascend/examples/pytest_ut/test_reduce_count_vector.py b/third_party/ascend/unittest/pytest_ut/test_reduce_count_vector.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_reduce_count_vector.py rename to third_party/ascend/unittest/pytest_ut/test_reduce_count_vector.py diff --git a/third_party/ascend/examples/pytest_ut/test_reduce_mean.py b/third_party/ascend/unittest/pytest_ut/test_reduce_mean.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_reduce_mean.py rename to third_party/ascend/unittest/pytest_ut/test_reduce_mean.py diff --git a/third_party/ascend/examples/pytest_ut/test_reduce_sum.py b/third_party/ascend/unittest/pytest_ut/test_reduce_sum.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_reduce_sum.py rename to third_party/ascend/unittest/pytest_ut/test_reduce_sum.py diff --git a/third_party/ascend/examples/pytest_ut/test_relu.py b/third_party/ascend/unittest/pytest_ut/test_relu.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_relu.py rename to third_party/ascend/unittest/pytest_ut/test_relu.py index 5b5a71374..63e9bd971 100644 --- a/third_party/ascend/examples/pytest_ut/test_relu.py +++ b/third_party/ascend/unittest/pytest_ut/test_relu.py @@ -23,7 +23,7 @@ import torch import pytest import test_common -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice def torch_relu(x0, x1): diff --git a/third_party/ascend/examples/pytest_ut/test_reshape.py b/third_party/ascend/unittest/pytest_ut/test_reshape.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_reshape.py rename to third_party/ascend/unittest/pytest_ut/test_reshape.py diff --git a/third_party/ascend/examples/pytest_ut/test_resize_performance.py b/third_party/ascend/unittest/pytest_ut/test_resize_performance.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_resize_performance.py rename to third_party/ascend/unittest/pytest_ut/test_resize_performance.py index 3ba2cabfc..6478e2f6b 100644 --- a/third_party/ascend/examples/pytest_ut/test_resize_performance.py +++ b/third_party/ascend/unittest/pytest_ut/test_resize_performance.py @@ -1,102 +1,102 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
- -import torch, torch_npu -import triton -import triton.language as tl -import numpy as np -import math -import pytest - - -@triton.jit -def nearest_resize_kernel_col_tile(img_src_ptr, img_dst_ptr, src_rows: tl.constexpr, src_cols: tl.constexpr, - dst_rows: tl.constexpr, dst_cols: tl.constexpr, RR_H: tl.constexpr, - RR_W: tl.constexpr, stride_in_h: tl.constexpr, stride_in_w: tl.constexpr, - stride_in_c: tl.constexpr, stride_out_h: tl.constexpr, stride_out_w: tl.constexpr, - stride_out_c: tl.constexpr, BLOCK_SIZE: tl.constexpr): - #RR_H和RR_W分别为高和宽的缩放比例 - block_id_c = tl.program_id(0) - block_id_h = tl.program_id(1) - block_id_w = tl.program_id(2) - - dest_w_offs = (block_id_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) - - dest_offs = (block_id_c[None, None] * stride_out_c + block_id_h[None, None] * stride_out_h + - dest_w_offs[None, :] * stride_out_w) - - fx = dest_w_offs * RR_W - sx = tl.floor(fx) - - new_col = block_id_h * RR_H - src_offsets = (block_id_c[None, None] * stride_in_c + new_col[None, None].to(tl.int32) * stride_in_h + - tl.clamp(sx, 0, src_cols - 1)[None, :].to(tl.int32) * stride_in_w) - src_val = tl.load(img_src_ptr + src_offsets) - dst_mask = dest_w_offs[None, :] < dst_cols - tl.store(img_dst_ptr + dest_offs, src_val, mask=dst_mask) - - -def nearest_resize_cpu(img_src, img_dst, dst_rows, dst_cols): - N, C, src_rows, src_cols = img_src.shape - # 2,4 64, 32 - #RR_H和RR_W分别为高和宽的缩放比例 - RR_H = src_rows / float(dst_rows) - RR_W = src_cols / float(dst_cols) - print("RR_H RR_W", RR_H, RR_W) - # 2, 2 - #根据output image的坐标值(i,j)计算input image的坐标值(sy, sx) - for i in range(dst_rows): #32 - for j in range(dst_cols): #16 - # fy = i * 2 = 0/1/..31 * 2 = 0/2/4...62 - fy = (i * RR_H) - sy = math.floor(fy) - # fx = j * 2 = 0/1/2..14 * 2 = 0/2/4...28 - fx = (j * RR_W) - sx = math.floor(fx) - src_val = img_src[0, :, np.clip(sy, 0, src_rows - 1), np.clip(sx, 0, src_cols - 1)] - # img_dst[0, :, i, j] 表示取批量中第 0 张图像、第 i 行第 j 列位置上的所有通道像素值 - img_dst[0, :, i, j] = src_val - - -def test_nearest_resize(): - n, c, h, w = 1, 4, 64, 64 - img_src = torch.randint(0, 255, size=(n, c, h, w)) - dst_rows = h // 2 - dst_cols = w // 2 - img_dst_cpu = torch.randint(0, 255, size=(n, c, dst_rows, dst_cols)) - nearest_resize_cpu(img_src, img_dst_cpu, dst_rows, dst_cols) - - # call triton kernel - img_src = img_src.npu() - RR_H = h / float(dst_rows) - RR_W = w / float(dst_cols) - img_dst_npu = torch.randint(0, 255, size=(n, c, dst_rows, dst_cols)).npu() - stride_in_h = h - stride_in_w = 1 - stride_in_c = h * w - stride_out_h = dst_rows - stride_out_w = 1 - stride_out_c = dst_rows * dst_cols - BLOCK_SIZE = 32 - # best performance - nearest_resize_kernel_col_tile[(4, 32, 1)](img_src, img_dst_npu, h, w, dst_rows, dst_cols, RR_H, RR_W, stride_in_h, - stride_in_w, stride_in_c, stride_out_h, stride_out_w, stride_out_c, - BLOCK_SIZE) - assert torch.equal(img_dst_cpu.cpu(), img_dst_npu.cpu()) +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import torch, torch_npu +import triton +import triton.language as tl +import numpy as np +import math +import pytest + + +@triton.jit +def nearest_resize_kernel_col_tile(img_src_ptr, img_dst_ptr, src_rows: tl.constexpr, src_cols: tl.constexpr, + dst_rows: tl.constexpr, dst_cols: tl.constexpr, RR_H: tl.constexpr, + RR_W: tl.constexpr, stride_in_h: tl.constexpr, stride_in_w: tl.constexpr, + stride_in_c: tl.constexpr, stride_out_h: tl.constexpr, stride_out_w: tl.constexpr, + stride_out_c: tl.constexpr, BLOCK_SIZE: tl.constexpr): + # RR_H and RR_W are the scaling ratios for height and width, respectively + block_id_c = tl.program_id(0) + block_id_h = tl.program_id(1) + block_id_w = tl.program_id(2) + + dest_w_offs = (block_id_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) + + dest_offs = (block_id_c[None, None] * stride_out_c + block_id_h[None, None] * stride_out_h + + dest_w_offs[None, :] * stride_out_w) + + fx = dest_w_offs * RR_W + sx = tl.floor(fx) + + new_col = block_id_h * RR_H + src_offsets = (block_id_c[None, None] * stride_in_c + new_col[None, None].to(tl.int32) * stride_in_h + + tl.clamp(sx, 0, src_cols - 1)[None, :].to(tl.int32) * stride_in_w) + src_val = tl.load(img_src_ptr + src_offsets) + dst_mask = dest_w_offs[None, :] < dst_cols + tl.store(img_dst_ptr + dest_offs, src_val, mask=dst_mask) + + +def nearest_resize_cpu(img_src, img_dst, dst_rows, dst_cols): + N, C, src_rows, src_cols = img_src.shape + # 2,4 64, 32 + # RR_H and RR_W are the scaling ratios for height and width, respectively + RR_H = src_rows / float(dst_rows) + RR_W = src_cols / float(dst_cols) + print("RR_H RR_W", RR_H, RR_W) + # 2, 2 + # Compute the input image coordinates (sy, sx) from the output image coordinates (i, j) + for i in range(dst_rows): #32 + for j in range(dst_cols): #16 + # fy = i * 2 = 0/1/..31 * 2 = 0/2/4...62 + fy = (i * RR_H) + sy = math.floor(fy) + # fx = j * 2 = 0/1/2..14 * 2 = 0/2/4...28 + fx = (j * RR_W) + sx = math.floor(fx) + src_val = img_src[0, :, np.clip(sy, 0, src_rows - 1), np.clip(sx, 0, src_cols - 1)] + # img_dst[0, :, i, j] selects the pixel values of all channels at row i, column j of image 0 in the batch + img_dst[0, :, i, j] = src_val + + +def test_nearest_resize(): + n, c, h, w = 1, 4, 64, 64 + img_src = torch.randint(0, 255, size=(n, c, h, w)) + dst_rows = h // 2 + dst_cols = w // 2 + img_dst_cpu = torch.randint(0, 255, size=(n, c, dst_rows, dst_cols)) + nearest_resize_cpu(img_src, img_dst_cpu, dst_rows, dst_cols) + + # call triton kernel + img_src = img_src.npu() + RR_H = h / float(dst_rows) + RR_W = w / float(dst_cols) + img_dst_npu = torch.randint(0, 255, size=(n, c, dst_rows, dst_cols)).npu() + stride_in_h 
= h + stride_in_w = 1 + stride_in_c = h * w + stride_out_h = dst_rows + stride_out_w = 1 + stride_out_c = dst_rows * dst_cols + BLOCK_SIZE = 32 + # best performance + nearest_resize_kernel_col_tile[(4, 32, 1)](img_src, img_dst_npu, h, w, dst_rows, dst_cols, RR_H, RR_W, stride_in_h, + stride_in_w, stride_in_c, stride_out_h, stride_out_w, stride_out_c, + BLOCK_SIZE) + assert torch.equal(img_dst_cpu.cpu(), img_dst_npu.cpu()) diff --git a/third_party/ascend/examples/pytest_ut/test_rint.py b/third_party/ascend/unittest/pytest_ut/test_rint.py similarity index 96% rename from third_party/ascend/examples/pytest_ut/test_rint.py rename to third_party/ascend/unittest/pytest_ut/test_rint.py index a77a3f11d..23b1bb3f1 100644 --- a/third_party/ascend/examples/pytest_ut/test_rint.py +++ b/third_party/ascend/unittest/pytest_ut/test_rint.py @@ -22,6 +22,8 @@ import triton import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice + import time import torch @@ -36,7 +38,7 @@ def triton_rint(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr, XBLOCK_SUB: tl. x0 = offset + loop1 + tl.arange(0, XBLOCK_SUB) xmask = x0 < xnumel tmp0 = tl.load(in_ptr0 + x0, mask=xmask) - tmp1 = tl.math.rint(tmp0) + tmp1 = libdevice.rint(tmp0) tl.store(out_ptr0 + x0, tmp1, mask=xmask) diff --git a/third_party/ascend/examples/pytest_ut/test_rms_norm.py b/third_party/ascend/unittest/pytest_ut/test_rms_norm.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_rms_norm.py rename to third_party/ascend/unittest/pytest_ut/test_rms_norm.py diff --git a/third_party/ascend/examples/pytest_ut/test_rotary_embedding.py b/third_party/ascend/unittest/pytest_ut/test_rotary_embedding.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_rotary_embedding.py rename to third_party/ascend/unittest/pytest_ut/test_rotary_embedding.py diff --git a/third_party/ascend/examples/pytest_ut/test_rotatry_gpt.py b/third_party/ascend/unittest/pytest_ut/test_rotatry_gpt.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_rotatry_gpt.py rename to third_party/ascend/unittest/pytest_ut/test_rotatry_gpt.py diff --git a/third_party/ascend/examples/pytest_ut/test_rotaty_embedding_gpt.py b/third_party/ascend/unittest/pytest_ut/test_rotaty_embedding_gpt.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_rotaty_embedding_gpt.py rename to third_party/ascend/unittest/pytest_ut/test_rotaty_embedding_gpt.py diff --git a/third_party/ascend/examples/pytest_ut/test_rshift.py b/third_party/ascend/unittest/pytest_ut/test_rshift.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_rshift.py rename to third_party/ascend/unittest/pytest_ut/test_rshift.py diff --git a/third_party/ascend/examples/pytest_ut/test_rsqrt.py b/third_party/ascend/unittest/pytest_ut/test_rsqrt.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_rsqrt.py rename to third_party/ascend/unittest/pytest_ut/test_rsqrt.py diff --git a/third_party/ascend/examples/pytest_ut/test_scalarPointer.py b/third_party/ascend/unittest/pytest_ut/test_scalarPointer.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_scalarPointer.py rename to third_party/ascend/unittest/pytest_ut/test_scalarPointer.py diff --git a/third_party/ascend/examples/pytest_ut/test_scalar_calc.py b/third_party/ascend/unittest/pytest_ut/test_scalar_calc.py similarity index 99% rename from 
third_party/ascend/examples/pytest_ut/test_scalar_calc.py rename to third_party/ascend/unittest/pytest_ut/test_scalar_calc.py index 10b588323..84ecdbeb1 100644 --- a/third_party/ascend/examples/pytest_ut/test_scalar_calc.py +++ b/third_party/ascend/unittest/pytest_ut/test_scalar_calc.py @@ -22,6 +22,8 @@ import torch_npu import triton import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice + import pytest import test_common @@ -654,7 +656,7 @@ def test_scalar_tanh_calc(param_list): def triton_kernel(out_ptr0, in_ptr0, N: tl.constexpr): idx = 0 tmp0 = tl.load(in_ptr0 + idx) - tmp1 = tl.math.tanh(tmp0) + tmp1 = libdevice.tanh(tmp0) tl.store(out_ptr0 + idx, tmp1) def torch_func(x0): diff --git a/third_party/ascend/unittest/pytest_ut/test_scope.py b/third_party/ascend/unittest/pytest_ut/test_scope.py new file mode 100755 index 000000000..9973348bd --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_scope.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +import os + +os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + +import pytest +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + + +class Options: + num_warps = 4 + num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {}, {}) + return str(module) + + +# ============== Kernel definitions ============== + + +@triton.jit +def kernel_nested_scope(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + """Test nested scopes.""" + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + with al.scope(core_mode="vector"): + with al.scope(core_mode="vector"): + with al.scope(core_mode="cube"): + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + result = x + y + tl.store(out_ptr + i, result, mask=i < n) + + +@triton.jit +def kernel_scope_escape(x_ptr, out_ptr, n, BLOCK: tl.constexpr): + """Test variable defined inside scope, used outside.""" + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + with al.scope(core_mode="vector"): + x = tl.load(x_ptr + i, mask=i < n) + # Use x outside of the scope + a = x + 1.0 + tl.store(out_ptr + i, a, mask=i < n) + + +@triton.jit +def kernel_scope_cube(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + """Test cube core mode.""" + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + with al.scope(core_mode="cube"): + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + result = x + y + tl.store(out_ptr + i, result, mask=i < n) + + +@triton.jit +def kernel_scope_vector(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + """Test vector core mode.""" + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + with al.scope(core_mode="vector"): + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + result = x + y + tl.store(out_ptr + i, result, mask=i < n) + + +@triton.jit +def kernel_scope_disable_auto_sync(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + """Test disable auto sync.""" + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + with 
al.scope(core_mode="vector", disable_auto_sync=True): + x = tl.load(x_ptr + i, mask=i < n) + y = tl.load(y_ptr + i, mask=i < n) + result = x + y + tl.store(out_ptr + i, result, mask=i < n) + + +# ============== Pytest tests ============== + + +def test_nested_scope(): + """Test nested scopes compile successfully.""" + mlir = compile_kernel(kernel_nested_scope, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}) + assert "scope.scope" in mlir + assert len(mlir) > 0 + + +def test_scope_escape(): + """Test variable escaping from scope.""" + mlir = compile_kernel(kernel_scope_escape, {"x_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, {"BLOCK": 256}) + assert "scope.scope" in mlir + assert len(mlir) > 0 + + +def test_scope_cube_mode(): + """Test cube core mode generates correct attributes.""" + mlir = compile_kernel(kernel_scope_cube, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}) + assert "scope.scope" in mlir + # Check for cube core type attribute + assert "hivm.tcore_type" in mlir or "CUBE" in mlir.upper() + + +def test_scope_vector_mode(): + """Test vector core mode generates correct attributes.""" + mlir = compile_kernel(kernel_scope_vector, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}) + assert "scope.scope" in mlir + # Check for vector core type attribute + assert "hivm.tcore_type" in mlir or "VECTOR" in mlir.upper() + + +def test_scope_disable_auto_sync(): + """Test disable auto sync generates correct attributes.""" + mlir = compile_kernel( + kernel_scope_disable_auto_sync, + {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}, + ) + assert "scope.scope" in mlir + # Check for disable auto sync attribute + assert "hivm.disable_auto_sync" in mlir + + +# ============== Main for manual testing ============== + +if __name__ == "__main__": + print("=" * 60) + print("Test 1: Nested Scopes") + print("=" * 60) + mlir = compile_kernel(kernel_nested_scope, {"x_ptr": "*fp32", "y_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, + {"BLOCK": 256}) + print(f"✅ Generated MLIR ({len(mlir)} chars):\n") + print(mlir) + + print("\n" + "=" * 60) + print("Test 2: Scope Escape") + print("=" * 60) + mlir = compile_kernel(kernel_scope_escape, {"x_ptr": "*fp32", "out_ptr": "*fp32", "n": "i32"}, {"BLOCK": 256}) + print(f"✅ Generated MLIR ({len(mlir)} chars):\n") + print(mlir) diff --git a/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_divsiop.py b/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_divsiop.py new file mode 100644 index 000000000..12e216672 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_select_analysis_for_divsiop.py @@ -0,0 +1,84 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import triton +import triton.language as tl +import torch +import torch_npu +import pytest +import test_common + + +def torch_divsiop_select_analysis(offs_num, divnum, maxindex, index, query): + idx = query[index] + offs_m = torch.arange(offs_num, dtype=torch.int32).npu() + query_pos = offs_m // divnum - idx + mask = query_pos < maxindex + + tensor0 = torch.tensor(0, dtype=torch.int32).npu() + tensor1 = torch.tensor(1, dtype=torch.int32).npu() + result = torch.where(mask, tensor1, tensor0).npu() + return result + + +@triton.jit +def divsiop_select_analysis_kernel1(index, out_ptr, query, offs_num: tl.constexpr, divnum: tl.constexpr, + maxindex: tl.constexpr): + idx = tl.load(query + index) + offs_m = tl.arange(0, offs_num) + + query_pos = offs_m // divnum - idx + + mask = query_pos < maxindex + query_mask = tl.where(mask, 1, 0).to(tl.int1) + tl.store(out_ptr + tl.arange(0, offs_num), query_mask) + + +@triton.jit +def divsiop_select_analysis_kernel2(index, out_ptr, query, offs_num: tl.constexpr, divnum: tl.constexpr, + maxindex: tl.constexpr): + idx = tl.load(query + index) + offs_m = tl.arange(0, offs_num) + + query_pos = -idx + offs_m // divnum + mask = query_pos < maxindex + + query_mask = tl.where(mask, 1, 0).to(tl.int1) + tl.store(out_ptr + tl.arange(0, offs_num), query_mask) + + +@pytest.mark.parametrize('param_list', [[16, 4, 2, index] for index in range(0, 4)]) +def test_divsiop_select_analysis1(param_list): + offs_num, divnum, maxindex, index = param_list + query = torch.tensor(range(0, divnum)).npu() + y_ref = torch_divsiop_select_analysis(offs_num, divnum, maxindex, index, query).npu() + y_cal = torch.full((offs_num, ), 2, dtype=torch.int32).npu() + divsiop_select_analysis_kernel1[(1, )](index, y_cal, query, offs_num, divnum, maxindex) + test_common.validate_cmp('int32', y_cal, y_ref) + + +@pytest.mark.parametrize('param_list', [[16, 4, 2, index] for index in range(0, 4)]) +def test_divsiop_select_analysis2(param_list): + offs_num, divnum, maxindex, index = param_list + query = torch.tensor(range(0, divnum)).npu() + y_ref = torch_divsiop_select_analysis(offs_num, divnum, maxindex, index, query).npu() + y_cal = torch.full((offs_num, ), 2, dtype=torch.int32).npu() + divsiop_select_analysis_kernel2[(1, )](index, y_cal, query, offs_num, divnum, maxindex) + test_common.validate_cmp('int32', y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_sigmoid.py b/third_party/ascend/unittest/pytest_ut/test_sigmoid.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_sigmoid.py rename to third_party/ascend/unittest/pytest_ut/test_sigmoid.py diff --git a/third_party/ascend/examples/pytest_ut/test_signbit.py b/third_party/ascend/unittest/pytest_ut/test_signbit.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_signbit.py rename to third_party/ascend/unittest/pytest_ut/test_signbit.py index 922fc3831..693b601c2 100644 --- a/third_party/ascend/examples/pytest_ut/test_signbit.py +++ b/third_party/ascend/unittest/pytest_ut/test_signbit.py @@ -24,7 
+24,7 @@ import torch_npu import pytest import test_common -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice @triton.jit diff --git a/third_party/ascend/examples/pytest_ut/test_silu.py b/third_party/ascend/unittest/pytest_ut/test_silu.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_silu.py rename to third_party/ascend/unittest/pytest_ut/test_silu.py diff --git a/third_party/ascend/examples/pytest_ut/test_sin.py b/third_party/ascend/unittest/pytest_ut/test_sin.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_sin.py rename to third_party/ascend/unittest/pytest_ut/test_sin.py diff --git a/third_party/ascend/examples/pytest_ut/test_sinh.py b/third_party/ascend/unittest/pytest_ut/test_sinh.py similarity index 97% rename from third_party/ascend/examples/pytest_ut/test_sinh.py rename to third_party/ascend/unittest/pytest_ut/test_sinh.py index c9a0cd8c3..81227bd86 100644 --- a/third_party/ascend/examples/pytest_ut/test_sinh.py +++ b/third_party/ascend/unittest/pytest_ut/test_sinh.py @@ -22,7 +22,7 @@ import triton import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice import test_common import torch diff --git a/third_party/ascend/examples/pytest_ut/test_softmax.py b/third_party/ascend/unittest/pytest_ut/test_softmax.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_softmax.py rename to third_party/ascend/unittest/pytest_ut/test_softmax.py diff --git a/third_party/ascend/unittest/pytest_ut/test_softmax_mindspore.py b/third_party/ascend/unittest/pytest_ut/test_softmax_mindspore.py new file mode 100644 index 000000000..9efd0fbcd --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_softmax_mindspore.py @@ -0,0 +1,99 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import os +import mindspore +import triton +import triton.language as tl +import numpy as np +from triton.runtime import driver +import pytest + +pytestmark = pytest.mark.backend("mindspore") + + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, + BLOCK_SIZE: tl.constexpr): + # starting row of the program + row_start = tl.program_id(0) + row_step = tl.num_programs(0) + for row_idx in tl.range(row_start, n_rows, row_step): + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + mask = col_offsets < n_cols + row = tl.load(input_ptrs, mask=mask, other=-float('inf')) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=mask) + + +target = triton.runtime.driver.active.get_current_target() +kernels = {} + + +def softmax(x): + n_rows, n_cols = x.shape + + # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x` + BLOCK_SIZE = triton.next_power_of_2(n_cols) + + # Allocate output + y = mindspore.mint.empty_like(x) + + # pre-compile kernel to get register usage and compute thread occupancy. + kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0)) + if kernel is None: + num_programs = 32 + kernel = softmax_kernel + kernels[BLOCK_SIZE] = (kernel, num_programs) + + num_programs = min(num_programs, n_rows) + + # Create a number of persistent programs. + kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE) + return y + + +@pytest.mark.parametrize('param_list', [ + ['float32', (1823, 781)], + ['float16', (1823, 781)], +]) +def test_softmax_mindspore(param_list): + os.environ["TRITON_BACKEND"] = "mindspore" + dtype, shape = param_list + mindspore.set_seed(0) + x = mindspore.ops.randn(shape, dtype=eval('mindspore.' 
+ dtype)) + output_triton = softmax(x) + output_mindspore = mindspore.ops.softmax(x, axis=1) + assert np.allclose(output_triton.asnumpy(), output_mindspore.asnumpy(), rtol=1e-3, atol=1e-3) + del os.environ["TRITON_BACKEND"] diff --git a/third_party/ascend/examples/pytest_ut/test_sort.py b/third_party/ascend/unittest/pytest_ut/test_sort.py similarity index 95% rename from third_party/ascend/examples/pytest_ut/test_sort.py rename to third_party/ascend/unittest/pytest_ut/test_sort.py index 2cc3514fe..62a103447 100644 --- a/third_party/ascend/examples/pytest_ut/test_sort.py +++ b/third_party/ascend/unittest/pytest_ut/test_sort.py @@ -23,6 +23,7 @@ import numpy as np import torch import triton.language as tl +import triton.language.extra.cann.extension as extension import test_common # --------------- @@ -36,7 +37,7 @@ def sort_kernel_2d(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.conste offy = tl.arange(0, N) * M off2d = offx[None, :] + offy[:, None] x = tl.load(X + off2d) - x = tl.sort(x, descending=descending, dim=1) + x = extension.sort(x, descending=descending, dim=1) tl.store(Z + off2d, x) @@ -72,7 +73,7 @@ def sort_kernel_3d(X, Z, D0: tl.constexpr, D1: tl.constexpr, D2: tl.constexpr, d off = off2[None, None, :] + off1[None, :, None] + off0[:, None, None] x = tl.load(X + off) - x = tl.sort(x, descending=descending, dim=2) + x = extension.sort(x, descending=descending, dim=2) tl.store(Z + off, x) diff --git a/third_party/ascend/examples/pytest_ut/test_split.py b/third_party/ascend/unittest/pytest_ut/test_split.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_split.py rename to third_party/ascend/unittest/pytest_ut/test_split.py diff --git a/third_party/ascend/examples/pytest_ut/test_sqrt.py b/third_party/ascend/unittest/pytest_ut/test_sqrt.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_sqrt.py rename to third_party/ascend/unittest/pytest_ut/test_sqrt.py diff --git a/third_party/ascend/examples/pytest_ut/test_static_print_and_assert.py b/third_party/ascend/unittest/pytest_ut/test_static_print_and_assert.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_static_print_and_assert.py rename to third_party/ascend/unittest/pytest_ut/test_static_print_and_assert.py diff --git a/third_party/ascend/examples/pytest_ut/test_store_scalar.py b/third_party/ascend/unittest/pytest_ut/test_store_scalar.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_store_scalar.py rename to third_party/ascend/unittest/pytest_ut/test_store_scalar.py diff --git a/third_party/ascend/unittest/pytest_ut/test_stride0_load.py b/third_party/ascend/unittest/pytest_ut/test_stride0_load.py new file mode 100644 index 000000000..9c9295381 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_stride0_load.py @@ -0,0 +1,66 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import torch +import torch_npu +import triton +import triton.language as tl + + +def torch_expanddims_load(in0, in1, out0, YBLOCK, XBLOCK): + for i in range(0, YBLOCK): + tmp = in0[i] + 10 + res = in1[tmp] + out0[i, :] = res + + +@triton.jit +def triton_expanddims_load(in0, in1, out0, YBLOCK: tl.constexpr, XBLOCK: tl.constexpr): + base_y = tl.arange(0, YBLOCK) + y_idx = base_y[:, None] + y_mask = y_idx < YBLOCK + base_x = tl.arange(0, XBLOCK) + x_idx = base_x[None, :] + x_mask = x_idx < XBLOCK + + y = tl.load(in0 + y_idx, mask=y_mask) + tmp0 = tl.full([YBLOCK, XBLOCK], 10, tl.int32) + tmp1 = y + tmp0 + res = tl.load(in1 + tmp1, mask=y_mask) + tl.store(out0 + (y_idx * XBLOCK + x_idx), res, mask=y_mask & x_mask) + + +def test_case(): + YBLOCK = 4 + XBLOCK = 8 + in0 = torch.arange(0, YBLOCK, device="npu", dtype=torch.int32) + in1 = torch.arange(0, YBLOCK + 10, device="npu", dtype=torch.int32) + out0 = torch.empty((YBLOCK, XBLOCK), device="npu", dtype=torch.int32) + + torch_expanddims_load(in0, in1, out0, YBLOCK, XBLOCK) + + in0_triton = in0 + in1_triton = in1 + out0_triton = torch.empty((YBLOCK * XBLOCK, ), device="npu", dtype=torch.int32) + + triton_expanddims_load[(1, 1, 1)](in0_triton, in1_triton, out0_triton, YBLOCK, XBLOCK) + out0_triton = out0_triton.view(YBLOCK, XBLOCK) + + assert torch.allclose(out0, out0_triton, rtol=1e-03, atol=1e-03, equal_nan=True) diff --git a/third_party/ascend/examples/pytest_ut/test_strides.py b/third_party/ascend/unittest/pytest_ut/test_strides.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_strides.py rename to third_party/ascend/unittest/pytest_ut/test_strides.py diff --git a/third_party/ascend/examples/pytest_ut/test_sub.py b/third_party/ascend/unittest/pytest_ut/test_sub.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_sub.py rename to third_party/ascend/unittest/pytest_ut/test_sub.py diff --git a/third_party/ascend/unittest/pytest_ut/test_sub_vec_id.py b/third_party/ascend/unittest/pytest_ut/test_sub_vec_id.py new file mode 100644 index 000000000..3fa24e76e --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_sub_vec_id.py @@ -0,0 +1,132 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import triton +import triton.language as tl +import pytest +import test_common +import triton.language.extra.cann.extension as al + + +@triton.jit +def triton_matmul_exp( + A_ptr, + B_ptr, + C_ptr, + TBuff_ptr, + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + sub_M: tl.constexpr, +): + """function: + 1) The matrix matmul + 2) The matrix exp computation + For example, + 1) [2, 3] @ [3, 5] -> [2, 5] + + [[-1.4310, 0.3144, 0.1952], [[-0.1099, 0.7062, 0.6576, 1.3056, 0.3783], + [ 1.6719, -0.2581, -1.0243]] @ [ 0.9769, -0.6924, 0.4765, 1.1012, 0.3814] + [-1.4598, -0.5444, 0.5582, -2.0959, -0.0568]] + -> + [[ 0.1795, -1.3346, -0.6822, -1.9311, -0.4324], + [ 1.0593, 1.9171, 0.4047, 4.0454, 0.5921]] + + 2) exp([2, 5]) + exp([[ 0.1795, -1.3346, -0.6822, -1.9311, -0.4324], -> [[ 1.1966, 0.2633, 0.5055, 0.1450, 0.6489], + [ 1.0593, 1.9171, 0.4047, 4.0454, 0.5921]]) [ 2.8845, 6.8013, 1.4988, 57.1358, 1.8078]] + """ + # Each program computes one element C[row, col] using 2D tl.dot + row_matmul = tl.program_id(0) + col = tl.program_id(1) + + # Build small 2D grids so tl.dot sees [M,K] x [K,N] + offs_i = tl.arange(0, tl.constexpr(M))[:, None] # [M,1] (row axis) + offs_j = tl.arange(0, N)[None, :] # [1,N] (col axis) + offs_k = tl.arange(0, K) # [K] + + # A row: [M, K] + a_ptrs = A_ptr + (row_matmul + offs_i) * K + offs_k[None, :] + a_vals = tl.load(a_ptrs) # [M, K] + + # B column: [K, N] + b_ptrs = B_ptr + offs_k[:, None] * N + (col + offs_j) + b_vals = tl.load(b_ptrs) # [K, N] + + tbuff_ptrs = TBuff_ptr + (row_matmul + offs_i) * N + (col + offs_j) + + # Dot: [M, K] @ [K, N] -> [M, N] + acc_11 = tl.dot(a_vals, b_vals) # [M, N] + tl.store(tbuff_ptrs, acc_11) + + # Load Matrix [M/2, N] + sub_vec_id = al.sub_vec_id() + row_exp = row_matmul + sub_M * sub_vec_id + offs_exp_i = tl.arange(0, tl.constexpr(sub_M))[:, None] # [M/2, 1] (row axis) + tbuff_exp_ptrs = TBuff_ptr + (row_exp + offs_exp_i) * N + (col + offs_j) + acc_11_reload = tl.load(tbuff_exp_ptrs) + # Pointer grid for the single output element: shape [M/2, N] + c_ptrs = C_ptr + (row_exp + offs_exp_i) * N + (col + offs_j) + + # Store exp(acc) without scalar indexing + tl.store(c_ptrs, tl.exp(acc_11_reload)) + + +@pytest.mark.parametrize( + "dtype, ashape, bshape", + [ + # dtype, A-shape, B-shape + ["float32", (2, 3), (3, 5)], + ["float32", (2, 1), (1, 5)], + ], +) +def test_sub_vec_id_1to2(dtype, ashape, bshape): + """function: + A 1:2 demo using sub_vec_id. + 1. The matrix computation and the vector computation unit each have their own independent Scalar scheduler units, + deploying separately on cube core and vector core. + 2. Combine cube core and vector core in a certain ratio (1:2) + + For example, [2, 3] @ [3, 5] -> [2, 5] matrix matmul computation and matrix exp([2, 5]) computation + using sub_vec_id was used during the matrix exp. 
+ """ + M, K = ashape + K2, N = bshape + assert K == K2, "Inner dimensions must match" + assert M % 2 == 0, "M dimensions must be divisible by 2" + sub_M = int(M / 2) + + # Generate input tensors + A = test_common.generate_tensor(ashape, dtype).npu() + B = test_common.generate_tensor(bshape, dtype).npu() + C = test_common.generate_tensor((M, N), dtype).npu() + TBuff = test_common.generate_tensor((M, N), dtype).npu() + + # Run + grid_matmul_exp = (1, ) # grid + triton_matmul_exp[grid_matmul_exp](A, B, C, TBuff, M, N, K, sub_M) + + # Reference result + C_ref = (A @ B).exp() + test_common.validate_cmp(dtype, C, C_ref) + + +if __name__ == "__main__": + test_sub_vec_id_1to2("float32", (2, 3), (3, 5)) diff --git a/third_party/ascend/unittest/pytest_ut/test_sub_vec_num.py b/third_party/ascend/unittest/pytest_ut/test_sub_vec_num.py new file mode 100644 index 000000000..b46302836 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_sub_vec_num.py @@ -0,0 +1,130 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+ +import triton +import triton.language as tl +import pytest +import test_common +import triton.language.extra.cann.extension as al + + +@triton.jit +def triton_matmul_exp( + A_ptr, + B_ptr, + C_ptr, + TBuff_ptr, + M: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, +): + """function: + 1) The matrix matmul + 2) The matrix exp computation + For example, + 1) [2, 3] @ [3, 5] -> [2, 5] + + [[-1.4310, 0.3144, 0.1952], [[-0.1099, 0.7062, 0.6576, 1.3056, 0.3783], + [ 1.6719, -0.2581, -1.0243]] @ [ 0.9769, -0.6924, 0.4765, 1.1012, 0.3814] + [-1.4598, -0.5444, 0.5582, -2.0959, -0.0568]] + -> + [[ 0.1795, -1.3346, -0.6822, -1.9311, -0.4324], + [ 1.0593, 1.9171, 0.4047, 4.0454, 0.5921]] + + 2) exp([2, 5]) + exp([[ 0.1795, -1.3346, -0.6822, -1.9311, -0.4324], -> [[ 1.1966, 0.2633, 0.5055, 0.1450, 0.6489], + [ 1.0593, 1.9171, 0.4047, 4.0454, 0.5921]]) [ 2.8845, 6.8013, 1.4988, 57.1358, 1.8078]] + """ + # Each program computes one element C[row, col] using 2D tl.dot + row_matmul = tl.program_id(0) + col = tl.program_id(1) + + # Build small 2D grids so tl.dot sees [M,K] x [K,N] + offs_i = tl.arange(0, tl.constexpr(M))[:, None] # [M,1] (row axis) + offs_j = tl.arange(0, N)[None, :] # [1,N] (col axis) + offs_k = tl.arange(0, K) # [K] + + # A row: [M, K] + a_ptrs = A_ptr + (row_matmul + offs_i) * K + offs_k[None, :] + a_vals = tl.load(a_ptrs) # [M, K] + + # B column: [K, N] + b_ptrs = B_ptr + offs_k[:, None] * N + (col + offs_j) + b_vals = tl.load(b_ptrs) # [K, N] + + tbuff_ptrs = TBuff_ptr + (row_matmul + offs_i) * N + (col + offs_j) + + # Dot: [M, K] @ [K, N] -> [M, N] + acc_11 = tl.dot(a_vals, b_vals) # [M, N] + tl.store(tbuff_ptrs, acc_11) + + # Load Matrix [M/2, N] + sub_vec_id = al.sub_vec_id() + row_exp = row_matmul + (M // al.sub_vec_num()) * sub_vec_id + offs_exp_i = tl.arange(0, M // al.sub_vec_num())[:, None] # [M/2, 1] (row axis) + tbuff_exp_ptrs = TBuff_ptr + (row_exp + offs_exp_i) * N + (col + offs_j) + acc_11_reload = tl.load(tbuff_exp_ptrs) + # Pointer grid for the single output element: shape [M/2, N] + c_ptrs = C_ptr + (row_exp + offs_exp_i) * N + (col + offs_j) + + # Store exp(acc) without scalar indexing + tl.store(c_ptrs, tl.exp(acc_11_reload)) + + +@pytest.mark.parametrize( + "dtype, ashape, bshape", + [ + # dtype, A-shape, B-shape + ["float32", (2, 3), (3, 5)], + ["float32", (2, 1), (1, 5)], + ], +) +def test_sub_vec_num_1to2(dtype, ashape, bshape): + """function: + A 1:2 demo using sub_vec_id. + 1. The matrix computation and the vector computation unit each have their own independent Scalar scheduler units, + deploying separately on cube core and vector core. + 2. Combine cube core and vector core in a certain ratio (1:2) + + For example, [2, 3] @ [3, 5] -> [2, 5] matrix matmul computation and matrix exp([2, 5]) computation + using sub_vec_id was used during the matrix exp. 
+ """ + M, K = ashape + K2, N = bshape + assert K == K2, "Inner dimensions must match" + assert M % 2 == 0, "M dimensions must be divisible by 2" + + # Generate input tensors + A = test_common.generate_tensor(ashape, dtype).npu() + B = test_common.generate_tensor(bshape, dtype).npu() + C = test_common.generate_tensor((M, N), dtype).npu() + TBuff = test_common.generate_tensor((M, N), dtype).npu() + + # Run + grid_matmul_exp = (1, ) # grid + triton_matmul_exp[grid_matmul_exp](A, B, C, TBuff, M, N, K) + + # Reference result + C_ref = (A @ B).exp() + test_common.validate_cmp(dtype, C, C_ref) + + +if __name__ == "__main__": + test_sub_vec_num_1to2("float32", (2, 3), (3, 5)) diff --git a/third_party/ascend/unittest/pytest_ut/test_subview.py b/third_party/ascend/unittest/pytest_ut/test_subview.py new file mode 100644 index 000000000..d5f6179f3 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_subview.py @@ -0,0 +1,89 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import os + +import pytest +import triton +import triton.language as tl +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +import triton.extension.buffer.language as bl +import triton.language.extra.cann.extension as al +from triton._C.libtriton import ir, buffer_ir +from triton._C.libtriton.ascend import ir as ascend_ir + +os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + + +class Options: + num_warps = 4 + num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + buffer_ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {"create_address_space": al.semantic.create_address_space}, + {}) + return str(module) + + +# ============== Kernel definitions ============== + + +@triton.jit +def test_subview_kernel1(XBLOCK: tl.constexpr): + # 1. 
Allocate a local buffer + src_buffer = bl.alloc(tl.float32, [XBLOCK, XBLOCK]) + + result_buffer = bl.subview(src_buffer, offsets=[1, 1], sizes=[XBLOCK - 2, XBLOCK - 2], strides=[1, 1]) + + +@triton.jit +def test_subview_kernel2(XBLOCK: tl.constexpr, offsets: tl.constexpr, sizes: tl.constexpr, strides: tl.constexpr): + # this statement has no effect, just to test the builder + bl.alloc(tl.float32, [XBLOCK]).subview([offsets], [sizes], [strides]) + + +# ============== Main for manual testing ============== + +if __name__ == "__main__": + print("=" * 60) + print("Test 1: test_subview_function") + print("=" * 60) + mlir = compile_kernel(test_subview_kernel1, {}, {"XBLOCK": 8}) + print(f"Generated MLIR ({len(mlir)} chars):\n") + print(mlir) + + print("\n" + "=" * 60) + print("Test 2: test_subview_constructor") + print("=" * 60) + mlir = compile_kernel(test_subview_kernel2, {}, {"XBLOCK": 32, "offsets": 1, "sizes": 30, "strides": 1}) + print(f"Generated MLIR ({len(mlir)} chars):\n") + print(mlir) diff --git a/third_party/ascend/examples/pytest_ut/test_sum.py b/third_party/ascend/unittest/pytest_ut/test_sum.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_sum.py rename to third_party/ascend/unittest/pytest_ut/test_sum.py diff --git a/third_party/ascend/examples/pytest_ut/test_sum_dim0.py b/third_party/ascend/unittest/pytest_ut/test_sum_dim0.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_sum_dim0.py rename to third_party/ascend/unittest/pytest_ut/test_sum_dim0.py diff --git a/third_party/ascend/examples/pytest_ut/test_sum_dim1.py b/third_party/ascend/unittest/pytest_ut/test_sum_dim1.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_sum_dim1.py rename to third_party/ascend/unittest/pytest_ut/test_sum_dim1.py diff --git a/third_party/ascend/examples/pytest_ut/test_sum_vector.py b/third_party/ascend/unittest/pytest_ut/test_sum_vector.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_sum_vector.py rename to third_party/ascend/unittest/pytest_ut/test_sum_vector.py diff --git a/third_party/ascend/examples/pytest_ut/test_swap.py b/third_party/ascend/unittest/pytest_ut/test_swap.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_swap.py rename to third_party/ascend/unittest/pytest_ut/test_swap.py diff --git a/third_party/ascend/examples/pytest_ut/test_swiglu.py b/third_party/ascend/unittest/pytest_ut/test_swiglu.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_swiglu.py rename to third_party/ascend/unittest/pytest_ut/test_swiglu.py diff --git a/third_party/ascend/examples/pytest_ut/test_swizzle2d.py b/third_party/ascend/unittest/pytest_ut/test_swizzle2d.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_swizzle2d.py rename to third_party/ascend/unittest/pytest_ut/test_swizzle2d.py diff --git a/third_party/ascend/examples/pytest_ut/test_sync_block.py b/third_party/ascend/unittest/pytest_ut/test_sync_block.py similarity index 92% rename from third_party/ascend/examples/pytest_ut/test_sync_block.py rename to third_party/ascend/unittest/pytest_ut/test_sync_block.py index e59d3edad..0c84b3166 100644 --- a/third_party/ascend/examples/pytest_ut/test_sync_block.py +++ b/third_party/ascend/unittest/pytest_ut/test_sync_block.py @@ -23,6 +23,10 @@ import pytest import test_common +import triton.language.extra.cann.extension as extension + +pipe = extension.PIPE + # eg: pytest -v 
test_matmul_exp.py::test_matmul_exp ############################# @@ -52,8 +56,8 @@ def triton_matmul_exp(A_ptr, B_ptr, C_ptr, TBuff_ptr, M, N, K: tl.constexpr): acc_11 = tl.dot(a_vals, b_vals) # [1, 1] tl.store(tbuff_ptrs, acc_11) - tl.sync_block_set("cube", "vector", 5) - tl.sync_block_wait("cube", "vector", 5) + extension.sync_block_set("cube", "vector", 5, pipe.PIPE_MTE1, pipe.PIPE_MTE3) + extension.sync_block_wait("cube", "vector", 5, pipe.PIPE_MTE1, pipe.PIPE_MTE3) acc_11_reload = tl.load(tbuff_ptrs) # Pointer grid for the single output element: shape [1,1] diff --git a/third_party/ascend/unittest/pytest_ut/test_sync_block_all.py b/third_party/ascend/unittest/pytest_ut/test_sync_block_all.py new file mode 100755 index 000000000..a41885432 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_sync_block_all.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 + +import os +import pytest +import triton +import triton.language as tl +import triton.language.extra.cann.extension as al +from triton.compiler.compiler import ASTSource +from triton.compiler.code_generator import ast_to_ttir +from triton._C.libtriton import ir +from triton._C.libtriton.ascend import ir as ascend_ir + +os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + + +class Options: + num_warps = 4 + num_stages = 3 + num_ctas = 1 + cluster_dims = (1, 1, 1) + enable_fp_fusion = True + debug = False + + +def compile_kernel(kernel, signature, constants): + """Helper to compile a kernel to MLIR.""" + src = ASTSource(kernel, signature, constants) + context = ir.context() + ir.load_dialects(context) + ascend_ir.load_dialects(context) + module = ast_to_ttir(kernel, src, context, Options(), {}, {}) + return str(module) + + +# ============== Kernel definitions ============== + + +@triton.jit +def test_sync_block_all(): + al.sync_block_all("all_cube", 8) + al.sync_block_all("all_vector", 9) + al.sync_block_all("all", 10) + al.sync_block_all("all_sub_vector", 11) + + +# ============== Main for manual testing ============== + +if __name__ == "__main__": + print("=" * 60) + print("Test 1: test_sync_block_all") + print("=" * 60) + mlir = compile_kernel(test_sync_block_all, {}, {}) + print(f"✅ Generated MLIR ({len(mlir)} chars):\n") + print(mlir) diff --git a/third_party/ascend/examples/pytest_ut/test_tan.py b/third_party/ascend/unittest/pytest_ut/test_tan.py similarity index 98% rename from third_party/ascend/examples/pytest_ut/test_tan.py rename to third_party/ascend/unittest/pytest_ut/test_tan.py index e39e6e90c..13f63f51a 100644 --- a/third_party/ascend/examples/pytest_ut/test_tan.py +++ b/third_party/ascend/unittest/pytest_ut/test_tan.py @@ -27,7 +27,7 @@ import torch import torch_npu -import triton.language.extra.ascend.libdevice as libdevice +import triton.language.extra.cann.libdevice as libdevice def standard_unary(x0, dtype): diff --git a/third_party/ascend/examples/model_cases/llama.py b/third_party/ascend/unittest/pytest_ut/test_tanh.py similarity index 52% rename from third_party/ascend/examples/model_cases/llama.py rename to third_party/ascend/unittest/pytest_ut/test_tanh.py index b6f856acc..00e47d349 100644 --- a/third_party/ascend/examples/model_cases/llama.py +++ b/third_party/ascend/unittest/pytest_ut/test_tanh.py @@ -18,41 +18,40 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
-import logging -import os +import pytest + +import triton +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +import test_common import torch import torch_npu -import torch_npu._inductor - -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig - -os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" - -logging.basicConfig(level=logging.DEBUG) - -torch.npu.config.allow_internal_format = False -torch.manual_seed(0) -torch.npu.manual_seed(0) -tokenizer = AutoTokenizer.from_pretrained("./Meta-Llama-3-8B") - -inputs = tokenizer("Hello, how to make China great again?", return_tensors="pt").to("npu:0") -model_ = AutoModelForCausalLM.from_pretrained("./Meta-Llama-3-8B", device_map="npu:0", _attn_implementation="eager") -model_.eval() - - -def model(**model_inputs): - with torch.no_grad(): - return model_(**model_inputs).logits - - -y = model(**inputs) -logging.info("result eager: " + str(torch.flatten(y)[:100])) - -model_compiled = torch.compile(model_) -z = model_compiled(**inputs) -logging.info("result compiled: " + str(torch.flatten(z)[:100])) -torch.testing.assert_close(y, z, atol=1e-4, rtol=1e-4) -logging.info("llama accuracy check pass!") +def torch_tanh(x0): + return torch.tanh(x0) + + +@triton.jit +def triton_tanh(in_ptr0, out_ptr0, XBLOCK: tl.constexpr, XBLOCK_SUB: tl.constexpr): + offset = tl.program_id(0) * XBLOCK + base1 = tl.arange(0, XBLOCK_SUB) + loops1: tl.constexpr = XBLOCK // XBLOCK_SUB + for loop1 in range(loops1): + x0 = offset + (loop1 * XBLOCK_SUB) + base1 + tmp0 = tl.load(in_ptr0 + (x0), None) + tmp1 = libdevice.tanh(tmp0) + tl.store(out_ptr0 + (x0), tmp1, None) + + +@pytest.mark.parametrize('param_list', [ + ['float32', (2, 4096, 8), 2, 32768, 1024], +]) +def test_tanh(param_list): + dtype, shape, ncore, xblock, xblock_sub = param_list + x0 = test_common.generate_tensor(shape, dtype).npu() + y_ref = torch_tanh(x0) + y_cal = torch.zeros(shape, dtype=eval('torch.' + dtype)).npu() + triton_tanh[ncore, 1, 1](x0, y_cal, xblock, xblock_sub) + test_common.validate_cmp(dtype, y_cal, y_ref) diff --git a/third_party/ascend/examples/pytest_ut/test_template.py b/third_party/ascend/unittest/pytest_ut/test_template.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_template.py rename to third_party/ascend/unittest/pytest_ut/test_template.py diff --git a/third_party/ascend/examples/pytest_ut/test_tensor_descriptor.py b/third_party/ascend/unittest/pytest_ut/test_tensor_descriptor.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_tensor_descriptor.py rename to third_party/ascend/unittest/pytest_ut/test_tensor_descriptor.py diff --git a/third_party/ascend/unittest/pytest_ut/test_to_buffer.py b/third_party/ascend/unittest/pytest_ut/test_to_buffer.py new file mode 100644 index 000000000..2d81971d5 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_to_buffer.py @@ -0,0 +1,54 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import torch + +import triton +import triton.language as tl +from triton.compiler import ASTSource +import triton.extension.buffer.language as bl +import triton.language.extra.cann.extension as al + +target = triton.runtime.driver.active.get_current_target() + + +@triton.jit +def to_buffer(): + a = tl.full((32, 2, 4), 0, dtype=tl.int64) + a_buf = bl.to_buffer(a) + b = tl.full((32, 2, 4), 0, dtype=tl.int64) + b_buf = bl.to_buffer(b, al.ascend_address_space.UB) + c = tl.full((32, 2, 4), 0, dtype=tl.int64) + c_buf = bl.to_buffer(c, al.ascend_address_space.L1) + d = tl.full((32, 2, 4), 0, dtype=tl.int64) + d_buf = bl.to_buffer(d, al.ascend_address_space.L0A) + e = tl.full((32, 2, 4), 0, dtype=tl.int64) + e_buf = bl.to_buffer(e, al.ascend_address_space.L0B) + f = tl.full((32, 2, 4), 0, dtype=tl.int64) + f_buf = bl.to_buffer(f, al.ascend_address_space.L0C) + + +def test_to_buffer(): + src = ASTSource( + fn=to_buffer, + constants={}, + signature={}, + ) + triton.compile(src=src, target=target) diff --git a/third_party/ascend/unittest/pytest_ut/test_to_tensor.py b/third_party/ascend/unittest/pytest_ut/test_to_tensor.py new file mode 100644 index 000000000..d75bed760 --- /dev/null +++ b/third_party/ascend/unittest/pytest_ut/test_to_tensor.py @@ -0,0 +1,75 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+
+import os
+
+import triton
+import triton.language as tl
+from triton.compiler.compiler import ASTSource
+from triton.compiler.code_generator import ast_to_ttir
+import triton.extension.buffer.language as bl
+import triton.language.extra.cann.extension as al
+from triton._C.libtriton import ir, buffer_ir
+from triton._C.libtriton.ascend import ir as ascend_ir
+
+os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
+
+
+class Options:
+    num_warps = 4
+    num_stages = 3
+    num_ctas = 1
+    cluster_dims = (1, 1, 1)
+    enable_fp_fusion = True
+    debug = False
+
+
+def compile_kernel(kernel, signature, constants):
+    """Helper to compile a kernel to MLIR."""
+    src = ASTSource(kernel, signature, constants)
+    context = ir.context()
+    ir.load_dialects(context)
+    buffer_ir.load_dialects(context)
+    ascend_ir.load_dialects(context)
+    module = ast_to_ttir(kernel, src, context, Options(), {"create_address_space": al.semantic.create_address_space},
+                         {})
+    return str(module)
+
+
+# ============== Kernel definitions ==============
+
+
+@triton.jit
+def kernel_func(XBLOCK: tl.constexpr):
+    buffer1 = bl.alloc(tl.float32, [XBLOCK])
+    buffer1.to_tensor(writable=True)
+    buffer2 = bl.alloc(tl.float32, [XBLOCK])
+    bl.to_tensor(buffer2, writable=True)
+
+
+# ============== Main for manual testing ==============
+
+if __name__ == "__main__":
+    print("=" * 60)
+    print("Test 1: to_tensor")
+    print("=" * 60)
+    mlir = compile_kernel(kernel_func, {}, {"XBLOCK": 256})
+    print(f"✅ Generated MLIR ({len(mlir)} chars):\n")
+    print(mlir)
diff --git a/third_party/ascend/examples/pytest_ut/test_top2gating_argmax.py b/third_party/ascend/unittest/pytest_ut/test_top2gating_argmax.py
similarity index 100%
rename from third_party/ascend/examples/pytest_ut/test_top2gating_argmax.py
rename to third_party/ascend/unittest/pytest_ut/test_top2gating_argmax.py
diff --git a/third_party/ascend/examples/pytest_ut/test_topk.py b/third_party/ascend/unittest/pytest_ut/test_topk.py
similarity index 100%
rename from third_party/ascend/examples/pytest_ut/test_topk.py
rename to third_party/ascend/unittest/pytest_ut/test_topk.py
diff --git a/third_party/ascend/examples/pytest_ut/test_trans_3d.py b/third_party/ascend/unittest/pytest_ut/test_trans_3d.py
similarity index 100%
rename from third_party/ascend/examples/pytest_ut/test_trans_3d.py
rename to third_party/ascend/unittest/pytest_ut/test_trans_3d.py
diff --git a/third_party/ascend/examples/pytest_ut/test_triton_eq.py b/third_party/ascend/unittest/pytest_ut/test_triton_eq.py
similarity index 100%
rename from third_party/ascend/examples/pytest_ut/test_triton_eq.py
rename to third_party/ascend/unittest/pytest_ut/test_triton_eq.py
diff --git a/third_party/ascend/examples/pytest_ut/test_triton_le.py b/third_party/ascend/unittest/pytest_ut/test_triton_le.py
similarity index 100%
rename from third_party/ascend/examples/pytest_ut/test_triton_le.py
rename to third_party/ascend/unittest/pytest_ut/test_triton_le.py
diff --git a/third_party/ascend/examples/pytest_ut/test_triton_lt.py b/third_party/ascend/unittest/pytest_ut/test_triton_lt.py
similarity index 100%
rename from third_party/ascend/examples/pytest_ut/test_triton_lt.py
rename to third_party/ascend/unittest/pytest_ut/test_triton_lt.py
diff --git a/third_party/ascend/examples/pytest_ut/test_triton_neq.py b/third_party/ascend/unittest/pytest_ut/test_triton_neq.py
similarity index 100%
rename from third_party/ascend/examples/pytest_ut/test_triton_neq.py
rename to third_party/ascend/unittest/pytest_ut/test_triton_neq.py
diff --git 
a/third_party/ascend/examples/pytest_ut/test_triton_unified_attention.py b/third_party/ascend/unittest/pytest_ut/test_triton_unified_attention.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_triton_unified_attention.py rename to third_party/ascend/unittest/pytest_ut/test_triton_unified_attention.py diff --git a/third_party/ascend/examples/pytest_ut/test_trunc.py b/third_party/ascend/unittest/pytest_ut/test_trunc.py similarity index 94% rename from third_party/ascend/examples/pytest_ut/test_trunc.py rename to third_party/ascend/unittest/pytest_ut/test_trunc.py index 137b29ee4..ec02d678b 100644 --- a/third_party/ascend/examples/pytest_ut/test_trunc.py +++ b/third_party/ascend/unittest/pytest_ut/test_trunc.py @@ -1,68 +1,68 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. - -import pytest -import triton -import torch -import triton.language as tl -import triton.language.extra.ascend.libdevice as libdevice -import test_common - - -@triton.jit -def trunc_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - mask = offsets < n_elements - - x = tl.load(x_ptr + offsets, mask=mask) - - y = libdevice.trunc(x) - - tl.store(y_ptr + offsets, y, mask=mask) - - -@pytest.mark.parametrize('shape', [ - (12, 16), -]) -@pytest.mark.parametrize('dtype', ['float32']) -def test_cases(shape, dtype): - n_elements = shape[0] * shape[1] - x = test_common.generate_tensor(shape, dtype).npu() - - # Make sure to include some edge cases. - x[0, 0] = 0.0 - x[0, 1] = 3.14 - x[0, 2] = -2.71 - x[0, 3] = 5.0 - x[0, 4] = -3.0 - - y = torch.empty_like(x) - - BLOCK_SIZE = 192 - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) - - trunc_kernel[grid](x, y, n_elements, BLOCK_SIZE=BLOCK_SIZE) - - expected = torch.trunc(x) - - torch.testing.assert_close(y, expected, rtol=1e-3, atol=1e-3) +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + +import pytest +import triton +import torch +import triton.language as tl +import triton.language.extra.cann.libdevice as libdevice +import test_common + + +@triton.jit +def trunc_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask) + + y = libdevice.trunc(x) + + tl.store(y_ptr + offsets, y, mask=mask) + + +@pytest.mark.parametrize('shape', [ + (12, 16), +]) +@pytest.mark.parametrize('dtype', ['float32']) +def test_cases(shape, dtype): + n_elements = shape[0] * shape[1] + x = test_common.generate_tensor(shape, dtype).npu() + + # Make sure to include some edge cases. 
+ x[0, 0] = 0.0 + x[0, 1] = 3.14 + x[0, 2] = -2.71 + x[0, 3] = 5.0 + x[0, 4] = -3.0 + + y = torch.empty_like(x) + + BLOCK_SIZE = 192 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + + trunc_kernel[grid](x, y, n_elements, BLOCK_SIZE=BLOCK_SIZE) + + expected = torch.trunc(x) + + torch.testing.assert_close(y, expected, rtol=1e-3, atol=1e-3) diff --git a/third_party/ascend/examples/pytest_ut/test_umulhi.py b/third_party/ascend/unittest/pytest_ut/test_umulhi.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_umulhi.py rename to third_party/ascend/unittest/pytest_ut/test_umulhi.py diff --git a/third_party/ascend/examples/pytest_ut/test_unlign_max_with_index.py b/third_party/ascend/unittest/pytest_ut/test_unlign_max_with_index.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_unlign_max_with_index.py rename to third_party/ascend/unittest/pytest_ut/test_unlign_max_with_index.py diff --git a/third_party/ascend/examples/pytest_ut/test_unlign_sum.py b/third_party/ascend/unittest/pytest_ut/test_unlign_sum.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_unlign_sum.py rename to third_party/ascend/unittest/pytest_ut/test_unlign_sum.py diff --git a/third_party/ascend/examples/pytest_ut/test_unused_func_arg.py b/third_party/ascend/unittest/pytest_ut/test_unused_func_arg.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_unused_func_arg.py rename to third_party/ascend/unittest/pytest_ut/test_unused_func_arg.py diff --git a/third_party/ascend/examples/pytest_ut/test_view.py b/third_party/ascend/unittest/pytest_ut/test_view.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_view.py rename to third_party/ascend/unittest/pytest_ut/test_view.py diff --git a/third_party/ascend/examples/pytest_ut/test_where_lt.py b/third_party/ascend/unittest/pytest_ut/test_where_lt.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_where_lt.py rename to third_party/ascend/unittest/pytest_ut/test_where_lt.py diff --git a/third_party/ascend/examples/pytest_ut/test_where_mask.py b/third_party/ascend/unittest/pytest_ut/test_where_mask.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_where_mask.py rename to third_party/ascend/unittest/pytest_ut/test_where_mask.py diff --git a/third_party/ascend/examples/pytest_ut/test_where_var.py b/third_party/ascend/unittest/pytest_ut/test_where_var.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_where_var.py rename to third_party/ascend/unittest/pytest_ut/test_where_var.py diff --git a/third_party/ascend/examples/pytest_ut/test_xor.py b/third_party/ascend/unittest/pytest_ut/test_xor.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_xor.py rename to third_party/ascend/unittest/pytest_ut/test_xor.py diff --git a/third_party/ascend/examples/pytest_ut/test_xor_sum.py b/third_party/ascend/unittest/pytest_ut/test_xor_sum.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_xor_sum.py rename to third_party/ascend/unittest/pytest_ut/test_xor_sum.py diff --git a/third_party/ascend/examples/pytest_ut/test_zeros.py b/third_party/ascend/unittest/pytest_ut/test_zeros.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_zeros.py rename to third_party/ascend/unittest/pytest_ut/test_zeros.py diff --git a/third_party/ascend/examples/pytest_ut/test_zeroslike.py 
b/third_party/ascend/unittest/pytest_ut/test_zeroslike.py similarity index 100% rename from third_party/ascend/examples/pytest_ut/test_zeroslike.py rename to third_party/ascend/unittest/pytest_ut/test_zeroslike.py diff --git a/third_party/ascend/examples/run_prtest.sh b/third_party/ascend/unittest/run_prtest.sh similarity index 65% rename from third_party/ascend/examples/run_prtest.sh rename to third_party/ascend/unittest/run_prtest.sh index 822216f5f..cbff6889b 100755 --- a/third_party/ascend/examples/run_prtest.sh +++ b/third_party/ascend/unittest/run_prtest.sh @@ -1,16 +1,17 @@ #!/bin/bash -# notice: 本脚本需运行在py311的TA环境上 +# notice: this script supports python3.11.x set -ex script=$(readlink -f "$0") script_dir=$(dirname "$script") -# 清理旧日志 +# clean old logs mkdir -p /home/pr_test_log +UNITTEST_DIR="triton-ascend/third_party/ascend/unittest" -# 新增:定义统计文件路径 -SUMMARY_FILE="${WORKSPACE}/triton-ascend/ascend/examples/summary.txt" +# define summary file path +SUMMARY_FILE="${WORKSPACE}/${UNITTEST_DIR}/summary.txt" function clean_cache() { if [ -d /tmp/torchinductor_* ];then @@ -38,20 +39,20 @@ function run_case_by_multi_card() { test_dir=$1 cd ${test_dir} - # 清理旧日志 + # clean logs rm -rf logs && mkdir logs - # 记录测试开始时间 + # record start time start_time=$(date +"%Y-%m-%d %H:%M:%S") - echo "===== 测试开始时间: ${start_time} =====" + echo "===== Test Start Time: ${start_time} =====" - # 运行测试并捕获退出状态 + # run tests and capture exit status set +e python -m pytest ${test_dir} -n auto --dist=loadfile -v --junitxml=logs/results.xml | tee logs/raw_output.log pytest_exit=$? set -e - # 处理日志(添加设备标签) + # process logs (add device tags) awk ' />> Worker gw[0-9]+ using NPU device/ { split($0, parts, / /) @@ -63,14 +64,14 @@ function run_case_by_multi_card() { { print "[" strftime("%Y-%m-%d %H:%M:%S") "| DEV-" dev_id "] " $0 } ' logs/raw_output.log > logs/combined.log - # 新增:解析测试结果统计 + # parse test result statistics total_tests=0 passed_tests=0 failed_tests=0 skipped_tests=0 error_tests=0 - # 使用Python解析JUnit XML报告 + # use Python to parse JUnit XML report python3 -c " import xml.etree.ElementTree as ET import os @@ -104,44 +105,44 @@ print(f'skipped_tests={skipped}') print(f'error_tests={errors}') " > logs/stats.tmp - # 加载统计结果 + # load stats source logs/stats.tmp rm logs/stats.tmp - # 记录测试结束时间 + # record end time end_time=$(date +"%Y-%m-%d %H:%M:%S") duration=$(( $(date -d "$end_time" +%s) - $(date -d "$start_time" +%s) )) duration_str=$(printf "%02dh %02dm %02ds" $((duration/3600)) $(((duration%3600)/60)) $((duration%60))) - # 新增:生成统计摘要 + # generate summary stats_summary=" -===== generalization_cases测试统计摘要 ===== -测试目录: $(basename ${test_dir}) -测试开始时间: ${start_time} -测试结束时间: ${end_time} -总耗时: ${duration_str} +===== Test Summary - [generalization_cases] ===== +Test Directory: $(basename ${test_dir}) +Test Start Time: ${start_time} +Test End Time: ${end_time} +Total Duration: ${duration_str} ------------------------ -总用例数: ${total_tests} -成功用例: ${passed_tests} -失败用例: ${failed_tests} -跳过用例: ${skipped_tests} -错误用例: ${error_tests} -成功率: $(( passed_tests * 100 / total_tests ))% (成功/总数) -设备数量: ${NPU_DEVICES} +Total Tests: ${total_tests} +Passed Tests: ${passed_tests} +Failed Tests: ${failed_tests} +Skipped Tests: ${skipped_tests} +Error Tests: ${error_tests} +Success Rate: $(( passed_tests * 100 / total_tests ))% (Passed/Total) +NPU Devices: ${NPU_DEVICES} ======================== " - # 输出统计信息到控制台 + # output stats summary to console echo "${stats_summary}" - # 追加统计信息到summary.txt + # append stats summary to summary.txt echo 
"${stats_summary}" >> ${SUMMARY_FILE} echo "========================================" echo "All tests completed!" echo "JUnit Report: logs/results.xml" echo "Combined Log: logs/combined.log" - echo "统计摘要已追加到: ${SUMMARY_FILE}" + echo "Stats Summary has been appended to: ${SUMMARY_FILE}" echo "========================================" zip_file=$2 @@ -149,33 +150,21 @@ print(f'error_tests={errors}') zip ${zip_file} combined.log cp ${zip_file} "/home/pr_test_log" - # 返回pytest的退出状态 + # return pytest exit status return $pytest_exit } -# 初始化统计文件 -echo "生成时间: $(date +"%Y-%m-%d %H:%M:%S")" >> ${SUMMARY_FILE} +# initialize stats file +echo "Generate Time: $(date +"%Y-%m-%d %H:%M:%S")" >> ${SUMMARY_FILE} echo "========================================" >> ${SUMMARY_FILE} # run gene case -zip_file="test_generalizetion_case_$(date +%Y%m%d).zip" -TEST_generalization="${WORKSPACE}/triton-ascend/ascend/examples/generalization_cases" +zip_file="test_generalization_case_$(date +%Y%m%d).zip" +TEST_generalization="${WORKSPACE}/${UNITTEST_DIR}/generalization_cases" clean_cache run_case_by_multi_card ${TEST_generalization} ${zip_file} echo "========================================" >> ${SUMMARY_FILE} -# run flaggems cases -TEST_flaggems_cases="${WORKSPACE}/triton-ascend/ascend/examples/flaggems_cases" -cd ${TEST_flaggems_cases} -clean_cache -bash run_flaggems_test.sh - -# run inductor cases -TEST_inductor_cases="${WORKSPACE}/triton-ascend/ascend/examples/inductor_cases" -cd ${TEST_inductor_cases} -clean_cache -bash run_inductor_test.sh - # copy summary.txt to /home/pr_test_log cp ${SUMMARY_FILE} /home/pr_test_log diff --git a/third_party/tests/ascend/vector-add.py b/third_party/tests/ascend/vector-add.py deleted file mode 100644 index 288c1d987..000000000 --- a/third_party/tests/ascend/vector-add.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -Vector Addition -=============== - -In this tutorial, you will write a simple vector addition using Triton. - -In doing so, you will learn about: - -* The basic programming model of Triton. - -* The `triton.jit` decorator, which is used to define Triton kernels. - -* The best practices for validating and benchmarking your custom ops against native reference implementations. - -""" - -# %% -# Compute Kernel -# -------------- - -import torch -import torch_npu - -import triton -import triton.language as tl - - -@triton.jit -def add_kernel(x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. - # NOTE: `constexpr` so it can be used as a shape value. - ): - # There are multiple 'programs' processing different data. We identify which program - # we are here: - pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. - # This program will process inputs that are offset from the initial data. - # For instance, if you had a vector of length 256 and block_size of 64, the programs - # would each access the elements [0:64, 64:128, 128:192, 192:256]. - # Note that offsets is a list of pointers: - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - # Create a mask to guard memory operations against out-of-bounds accesses. - mask = offsets < n_elements - # Load x and y from DRAM, masking out any extra elements in case the input is not a - # multiple of the block size. 
- x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - # Write x + y back to DRAM. - tl.store(output_ptr + offsets, output, mask=mask) - - -# %% -# Let's also declare a helper function to (1) allocate the `z` tensor -# and (2) enqueue the above kernel with appropriate grid/block sizes: - - -def add(x: torch.Tensor, y: torch.Tensor): - output = torch.empty_like(x) - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) - return output - - -# %% -# We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness: -torch.manual_seed(0) -size = 98432 -x = torch.rand(size, device='npu') -y = torch.rand(size, device='npu') -output_torch = x + y -output_triton = add(x, y) -print(output_torch) -print(output_triton) -print(f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}')