diff --git a/CMakeLists.txt b/CMakeLists.txt index 49633a6..fbc4717 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + cmake_minimum_required(VERSION 3.18 FATAL_ERROR) project(AOTriton CXX C) diff --git a/Makefile b/Makefile deleted file mode 100644 index 9cbc77c..0000000 --- a/Makefile +++ /dev/null @@ -1,43 +0,0 @@ -NPROC=$(shell nproc) - -v2all: - mkdir -p build - python -m v2python.generate_compile --target_gpus MI200 - # (. build/venv/bin/activate; cd build; LD_PRELOAD=/opt/rocm/lib/libamdocl64.so make -j $(NPROC) -f Makefile.compile) - python -m v2python.generate_shim --target_gpus MI200 --build_dir build - # (. build/venv/bin/activate; cd build/flash/; hipcc -std=c++20 -c -I../../include -I../../third_party/incbin/ attn_fwd.cc) - # (. build/venv/bin/activate; cd build/flash/autotune.attn_fwd; hipcc -std=c++20 -c -I../../../include -I../../../third_party/incbin/ 'FONLY__^bf16@16,1,128,False,True___MI200.cc') - (. build/venv/bin/activate; cd build; make -j $(NPROC) -f Makefile.shim) - -check: - nm -DC build/libaotriton_v2.so |grep aotriton|grep 'U ' - -test: - PYTHONPATH=39build/bindings/ pytest -s test/test_forward.py - -all: - mkdir -p build - python python/generate.py --target MI200 - (. build/venv/bin/activate; cd build; LD_PRELOAD=/opt/rocm/lib/libamdocl64.so make -j $(NPROC) -f Makefile.compile) - python python/generate_shim.py - (. build/venv/bin/activate; cd build; make -j $(NPROC) -f Makefile.shim) - -format: - find bindings/ include/ v2src/ \( -name '*.h' -or -name '*.cc' \) -not -path '*template/*' -exec clang-format -i {} \; - -test_compile: - hipcc -o build/test_compile test/test_compile.cc -L build -laotriton -Wl,-rpath=. -I/opt/rocm/include -Ibuild/ - -clean: - (cd build/; rm -f *.h *.so *.cc *.o *.json *.hsaco) - -create_venv: - python -m venv build/venv - -triton_install_develop: - (. build/venv/bin/activate; pip install -r requirements.txt; cd third_party/triton/python/; pip install -e .) - -triton_install: - (. build/venv/bin/activate; pip install -r requirements.txt; cd third_party/triton/python/; pip install .) - -.PHONY: all clean test_compile create_venv triton_install triton_install_develop check test diff --git a/README.md b/README.md index 574da98..2f1734b 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,18 @@ -## Usage +## Build Instructions ``` mkdir build -python python/generate.py -(cd build; make -j `nproc` -f Makefile.compile) -python python/generate_shim.py -(cd build; make -j `nproc` -f Makefile.shim) +cd build +cmake .. -DCMAKE_INSTALL_PREFIX=./install_dir +# Use ccmake to tweak options +make install ``` -Then the `attn_fwd.so` and `attn_fwd.h` can be found under `build/` +The library and the header file can be found under `build/install_dir` afterwards. + +Note: do not run `make` separately, due to the limit of the current build +system, `make install` will run the whole build process unconditionally. ### Prerequisites -* `hipcc` -* `triton` +* `hipcc` in `/opt/rocm/bin`, as a part of [ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/) diff --git a/bindings/CMakeLists.txt b/bindings/CMakeLists.txt index db7cd5f..dbf6979 100644 --- a/bindings/CMakeLists.txt +++ b/bindings/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + aux_source_directory(. PYAOTRITON_SRC) # find_package(hip REQUIRED) pybind11_add_module(pyaotriton ${PYAOTRITON_SRC}) diff --git a/bindings/hipruntime.cc b/bindings/hipruntime.cc index bff11f9..86a86b7 100644 --- a/bindings/hipruntime.cc +++ b/bindings/hipruntime.cc @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #include #include namespace py = pybind11; diff --git a/bindings/module.cc b/bindings/module.cc index a11b014..19cf8e9 100644 --- a/bindings/module.cc +++ b/bindings/module.cc @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #include #include #include diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index e0a6885..1210ec8 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + message("CMAKE_SOURCE_DIR ${CMAKE_SOURCE_DIR}") message("CMAKE_CURRENT_LIST_DIR ${CMAKE_CURRENT_LIST_DIR}") message("CMAKE_CURRENT_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}") diff --git a/csrc/README.md b/csrc/README.md new file mode 100644 index 0000000..72608bd --- /dev/null +++ b/csrc/README.md @@ -0,0 +1,3 @@ +# Deprecated Content + +The code under this directory is deprecated, and will be removed in future releases. diff --git a/csrc/aotriton_kernel.h b/csrc/aotriton_kernel.h index dc01fb6..643d1db 100644 --- a/csrc/aotriton_kernel.h +++ b/csrc/aotriton_kernel.h @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #ifndef AOTRITON_KERNEL_H #define AOTRITON_KERNEL_H diff --git a/csrc/template/kernel_shim.cc b/csrc/template/kernel_shim.cc index d0ba1d8..5677870 100644 --- a/csrc/template/kernel_shim.cc +++ b/csrc/template/kernel_shim.cc @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #define INCBIN_PREFIX g_aotriton_kernel_for_shim_ #define INCBIN_STYLE INCBIN_STYLE_SNAKE #include diff --git a/csrc/template/kernel_shim.footer.h b/csrc/template/kernel_shim.footer.h index 0d54950..88ba795 100644 --- a/csrc/template/kernel_shim.footer.h +++ b/csrc/template/kernel_shim.footer.h @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + }; // namespace aotriton::v1 #endif diff --git a/csrc/template/kernel_shim.header.h b/csrc/template/kernel_shim.header.h index bebe56a..d6e9815 100644 --- a/csrc/template/kernel_shim.header.h +++ b/csrc/template/kernel_shim.header.h @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #ifndef AOTRITON_{shim_kernel_name}_H #define AOTRITON_{shim_kernel_name}_H diff --git a/include/aotriton/_internal/triton_kernel.h b/include/aotriton/_internal/triton_kernel.h index 38ca133..690894e 100644 --- a/include/aotriton/_internal/triton_kernel.h +++ b/include/aotriton/_internal/triton_kernel.h @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #ifndef AOTRITON_V2_API_TRITON_KERNEL_H #define AOTRITON_V2_API_TRITON_KERNEL_H diff --git a/include/aotriton/dtypes.h b/include/aotriton/dtypes.h index 12991d9..b1841da 100644 --- a/include/aotriton/dtypes.h +++ b/include/aotriton/dtypes.h @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #ifndef AOTRITON_V2_API_DTYPES_H #define AOTRITON_V2_API_DTYPES_H diff --git a/include/aotriton/flash.h b/include/aotriton/flash.h index fedae11..ae30751 100644 --- a/include/aotriton/flash.h +++ b/include/aotriton/flash.h @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #ifndef AOTRITON_V2_API_FLASH_ATTN_H #define AOTRITON_V2_API_FLASH_ATTN_H diff --git a/include/aotriton/runtime.h b/include/aotriton/runtime.h index d8b86db..229b94f 100644 --- a/include/aotriton/runtime.h +++ b/include/aotriton/runtime.h @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #ifndef AOTRITON_V2_API_RUNTIME_H #define AOTRITON_V2_API_RUNTIME_H diff --git a/include/aotriton/util.h b/include/aotriton/util.h index 1666019..13ef750 100644 --- a/include/aotriton/util.h +++ b/include/aotriton/util.h @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #ifndef AOTRITON_V2_API_UTIL_H #define AOTRITON_V2_API_UTIL_H diff --git a/python/README.md b/python/README.md new file mode 100644 index 0000000..72608bd --- /dev/null +++ b/python/README.md @@ -0,0 +1,3 @@ +# Deprecated Content + +The code under this directory is deprecated, and will be removed in future releases. diff --git a/python/compile.py b/python/compile.py deleted file mode 100755 index 9aaa0e5..0000000 --- a/python/compile.py +++ /dev/null @@ -1,132 +0,0 @@ -#!/usr/bin/env python - -import hashlib -import importlib.util -import sys -from argparse import ArgumentParser -from pathlib import Path -from typing import List - -import triton - -import shutil -import subprocess -import json - -desc = """ -Triton ahead-of-time compiler: -""" - -def main(): - # command-line arguments - parser = ArgumentParser(description=desc) - parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.") - parser.add_argument("--target", type=str, default=None, help="Ahead of Time (AOT) Compile Architecture. PyTorch is required for autodetection if --target is missing.") - parser.add_argument("--kernel_name", "-n", type=str, default="", help="Name of the kernel to compile", required=True) - parser.add_argument("--num_warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") - parser.add_argument("--num_stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)") - parser.add_argument("--waves_per_eu", type=int, default=0) - parser.add_argument("--out_path", "-o", type=Path, default=None, help="Out filename", required=True) - parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) - parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) - parser.add_argument("--verbose", "-v", help="Enable vebose output", action='store_true') - parser.add_argument("--nostrip", help="Keep debugging symbols", action='store_true') - args = parser.parse_args() - - out_path = args.out_path - out_path = out_path.with_suffix('') - - # execute python sources and extract functions wrapped in JITFunction - arg_path = Path(args.path) - sys.path.insert(0, str(arg_path.parent)) - ''' - spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - kernel = getattr(mod, args.kernel_name) - ''' - if True: - exec_string = f'import {arg_path.stem}' - # print(exec_string) - exec(exec_string, globals()) # importlib code path miss things - # print(globals()) - # kernel = globals()[f"{arg_path.stem}.{args.kernel_name}"] - mod = globals()[arg_path.stem] - kernel = getattr(mod, args.kernel_name) - # print(fused_attention_trimmed.attn_fwd) - if False: - mod = importlib.import_module(arg_path.stem) - print(mod.attn_fwd) - # print(fused_attention_trimmed.attn_fwd) - kernel = globals()[f"{arg_path.stem}.{args.kernel_name}"] - print(f"{kernel=}") - - grid = args.grid.split(",") - assert len(grid) == 3 - - # validate and parse signature - signature = list(map(lambda s: s.strip(" "), args.signature.split(","))) - - def hash_signature(signature: List[str]): - m = hashlib.sha256() - m.update(" ".join(signature).encode()) - return m.hexdigest()[:8] - - meta_sig = f"warps{args.num_warps}xstages{args.num_stages}" - sig_hash = hash_signature(signature + [meta_sig]) - - def constexpr(s): - try: - ret = int(s) - return ret - except ValueError: - pass - try: - ret = float(s) - return ret - except ValueError: - pass - if s == 'True': - return True - if s == 'False': - return False - return None - - hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} - hints = {k: v for k, v in hints.items() if v is not None} - constexprs = {i: constexpr(s) for i, s in enumerate(signature)} - constexprs = {k: v for k, v in constexprs.items() if v is not None} - # print(f"{constexprs=}") - signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constexprs} - const_sig = 'x'.join([str(v) for v in constexprs.values()]) - doc_string = [f"{kernel.arg_names[i]}={constexprs[i]}" for i in constexprs.keys()] - doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] - - # compile ast into cubin - for h in hints.values(): - assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" - divisible_by_16 = [i for i, h in hints.items() if h == 16] - equal_to_1 = [i for i, h in hints.items() if h == 1] - config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) - for i in equal_to_1: - constexprs.update({i: 1}) - # print(f'{kernel=}') - ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages, waves_per_eu=args.waves_per_eu, aot_arch=args.target) - hsaco_path = ccinfo.asm.get('hsaco_path', None) - if args.verbose: - print(dir(ccinfo)) - print(f'{ccinfo.asm.keys()=}') - print(f'{ccinfo.fn=}') - print(f'{hsaco_path=}') - - if hsaco_path is not None: - if args.nostrip: - shutil.copy(hsaco_path, out_path.with_suffix('.hsaco')) - else: - subprocess.run(['/opt/rocm/llvm/bin/llvm-objcopy', '--remove-section', '.debug_*', str(hsaco_path), str(out_path.with_suffix('.hsaco'))]) - - with out_path.with_suffix('.json').open("w") as fp: - json.dump(ccinfo.metadata, fp, indent=2) - -if __name__ == "__main__": - main() diff --git a/python/compile.py b/python/compile.py new file mode 120000 index 0000000..22ff193 --- /dev/null +++ b/python/compile.py @@ -0,0 +1 @@ +../v2python/compile.py \ No newline at end of file diff --git a/python/generate.py b/python/generate.py index 4335b08..29c3526 100644 --- a/python/generate.py +++ b/python/generate.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + import rules import io import shutil diff --git a/python/generate_shim.py b/python/generate_shim.py index 31623da..49ec3f0 100755 --- a/python/generate_shim.py +++ b/python/generate_shim.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + import rules import io import shutil diff --git a/python/kernel_desc.py b/python/kernel_desc.py index e474b2c..2ca802d 100644 --- a/python/kernel_desc.py +++ b/python/kernel_desc.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT from abc import ABC, abstractmethod from pathlib import Path diff --git a/python/object_desc.py b/python/object_desc.py index 19d4f11..36b4753 100644 --- a/python/object_desc.py +++ b/python/object_desc.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT from pathlib import Path import json diff --git a/python/rules.py b/python/rules.py index d9d5994..37f3ff3 100644 --- a/python/rules.py +++ b/python/rules.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from kernel_desc import KernelDescription, get_possible_types def _pattern(arguments, prefix): diff --git a/test/aotriton_flash.py b/test/aotriton_flash.py index e448dd4..1c50bf9 100644 --- a/test/aotriton_flash.py +++ b/test/aotriton_flash.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from pyaotriton.v2.flash import attn_fwd as fa_forward, attn_bwd as fa_backward from pyaotriton import T1, T2, T4, DType, Stream diff --git a/test/attn_torch_function.py b/test/attn_torch_function.py index 34f5e56..0c57c94 100644 --- a/test/attn_torch_function.py +++ b/test/attn_torch_function.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT import torch from aotriton_flash import attn_fwd, attn_bwd diff --git a/test/bwd_preprocess.py b/test/bwd_preprocess.py index 43313e2..96d3fee 100644 --- a/test/bwd_preprocess.py +++ b/test/bwd_preprocess.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT """ Fused Attention diff --git a/test/bwd_split_kernel.py b/test/bwd_split_kernel.py index 8beeb45..153f874 100644 --- a/test/bwd_split_kernel.py +++ b/test/bwd_split_kernel.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT """ Fused Attention diff --git a/test/fwd_kernel.py b/test/fwd_kernel.py index 0637bc4..401383a 100644 --- a/test/fwd_kernel.py +++ b/test/fwd_kernel.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT """ Fused Attention diff --git a/test/test_backward.py b/test/test_backward.py index f29f4ef..dc6e339 100644 --- a/test/test_backward.py +++ b/test/test_backward.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT import pytest import torch diff --git a/test/test_forward.py b/test/test_forward.py index 027045d..d0fc4dd 100644 --- a/test/test_forward.py +++ b/test/test_forward.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT import pytest import torch diff --git a/test/triton_attn_torch_function.py b/test/triton_attn_torch_function.py index 0ed157a..f8944a2 100644 --- a/test/triton_attn_torch_function.py +++ b/test/triton_attn_torch_function.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT import torch import triton diff --git a/test/triton_backward.py b/test/triton_backward.py index 9c1f200..0d5d519 100644 --- a/test/triton_backward.py +++ b/test/triton_backward.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT import pytest import torch diff --git a/test/triton_forward.py b/test/triton_forward.py index 151acdb..1bb902d 100644 --- a/test/triton_forward.py +++ b/test/triton_forward.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT import pytest import torch diff --git a/test/v1_test_compile.cc b/test/v1_test_compile.cc index e927cd3..dee9368 100644 --- a/test/v1_test_compile.cc +++ b/test/v1_test_compile.cc @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #include #include "attn_fwd.h" diff --git a/tritonsrc/attn_torch_function.py b/tritonsrc/attn_torch_function.py index abe2522..7cec806 100644 --- a/tritonsrc/attn_torch_function.py +++ b/tritonsrc/attn_torch_function.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT import torch import triton diff --git a/tritonsrc/bwd_preprocess.py b/tritonsrc/bwd_preprocess.py index 43313e2..96d3fee 100644 --- a/tritonsrc/bwd_preprocess.py +++ b/tritonsrc/bwd_preprocess.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT """ Fused Attention diff --git a/tritonsrc/bwd_split_kernel.py b/tritonsrc/bwd_split_kernel.py index a47fd4e..d13c363 100644 --- a/tritonsrc/bwd_split_kernel.py +++ b/tritonsrc/bwd_split_kernel.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT """ Fused Attention diff --git a/tritonsrc/flash.py b/tritonsrc/flash.py index eacaf2a..f3a41bb 100644 --- a/tritonsrc/flash.py +++ b/tritonsrc/flash.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + """ Fused Attention =============== diff --git a/tritonsrc/fwd_kernel.py b/tritonsrc/fwd_kernel.py index 0637bc4..401383a 100644 --- a/tritonsrc/fwd_kernel.py +++ b/tritonsrc/fwd_kernel.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT """ Fused Attention diff --git a/tritonsrc/performance_forward.py b/tritonsrc/performance_forward.py index 50e4cc7..2977523 100644 --- a/tritonsrc/performance_forward.py +++ b/tritonsrc/performance_forward.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT import pytest import torch diff --git a/tritonsrc/rocm_arch.py b/tritonsrc/rocm_arch.py index ffdd428..8578d69 100644 --- a/tritonsrc/rocm_arch.py +++ b/tritonsrc/rocm_arch.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + import subprocess def rocm_get_gpuarch(): diff --git a/tritonsrc/test_backward.py b/tritonsrc/test_backward.py index 104fd53..d170863 100644 --- a/tritonsrc/test_backward.py +++ b/tritonsrc/test_backward.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT import pytest import torch diff --git a/tritonsrc/test_forward.py b/tritonsrc/test_forward.py index 497b505..2419780 100644 --- a/tritonsrc/test_forward.py +++ b/tritonsrc/test_forward.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT import pytest import torch diff --git a/tritonsrc/tune_flash.py b/tritonsrc/tune_flash.py index 5796234..3b3265b 100644 --- a/tritonsrc/tune_flash.py +++ b/tritonsrc/tune_flash.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT import pytest import torch diff --git a/tritonsrc/v1/bwd_preprocess.py b/tritonsrc/v1/bwd_preprocess.py index ff3203a..0078197 100644 --- a/tritonsrc/v1/bwd_preprocess.py +++ b/tritonsrc/v1/bwd_preprocess.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT """ Fused Attention diff --git a/tritonsrc/v1/bwd_split_kernel.py b/tritonsrc/v1/bwd_split_kernel.py index 1e1e798..9bfdd44 100644 --- a/tritonsrc/v1/bwd_split_kernel.py +++ b/tritonsrc/v1/bwd_split_kernel.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT """ Fused Attention diff --git a/tritonsrc/v1/fused_attention_trimmed.py b/tritonsrc/v1/fused_attention_trimmed.py index c4e0690..e685d8b 100644 --- a/tritonsrc/v1/fused_attention_trimmed.py +++ b/tritonsrc/v1/fused_attention_trimmed.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + """ Fused Attention =============== diff --git a/tritonsrc/v1/fwd_kernel.py b/tritonsrc/v1/fwd_kernel.py index dced207..5e38746 100644 --- a/tritonsrc/v1/fwd_kernel.py +++ b/tritonsrc/v1/fwd_kernel.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT """ Fused Attention diff --git a/v2python/autotune_binning.py b/v2python/autotune_binning.py index f11d043..460344e 100644 --- a/v2python/autotune_binning.py +++ b/v2python/autotune_binning.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + class Binning(object): pass diff --git a/v2python/compile.py b/v2python/compile.py deleted file mode 120000 index 13eab3c..0000000 --- a/v2python/compile.py +++ /dev/null @@ -1 +0,0 @@ -../python/compile.py \ No newline at end of file diff --git a/v2python/compile.py b/v2python/compile.py new file mode 100755 index 0000000..cad0e38 --- /dev/null +++ b/v2python/compile.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +import hashlib +import importlib.util +import sys +from argparse import ArgumentParser +from pathlib import Path +from typing import List + +import triton + +import shutil +import subprocess +import json + +desc = """ +Triton ahead-of-time compiler: +""" + +def main(): + # command-line arguments + parser = ArgumentParser(description=desc) + parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.") + parser.add_argument("--target", type=str, default=None, help="Ahead of Time (AOT) Compile Architecture. PyTorch is required for autodetection if --target is missing.") + parser.add_argument("--kernel_name", "-n", type=str, default="", help="Name of the kernel to compile", required=True) + parser.add_argument("--num_warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") + parser.add_argument("--num_stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)") + parser.add_argument("--waves_per_eu", type=int, default=0) + parser.add_argument("--out_path", "-o", type=Path, default=None, help="Out filename", required=True) + parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) + parser.add_argument("--grid", "-g", type=str, help="Launch grid of the kernel", required=True) + parser.add_argument("--verbose", "-v", help="Enable vebose output", action='store_true') + parser.add_argument("--nostrip", help="Keep debugging symbols", action='store_true') + args = parser.parse_args() + + out_path = args.out_path + out_path = out_path.with_suffix('') + + # execute python sources and extract functions wrapped in JITFunction + arg_path = Path(args.path) + sys.path.insert(0, str(arg_path.parent)) + ''' + spec = importlib.util.spec_from_file_location(arg_path.stem, arg_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + kernel = getattr(mod, args.kernel_name) + ''' + if True: + exec_string = f'import {arg_path.stem}' + # print(exec_string) + exec(exec_string, globals()) # importlib code path miss things + # print(globals()) + # kernel = globals()[f"{arg_path.stem}.{args.kernel_name}"] + mod = globals()[arg_path.stem] + kernel = getattr(mod, args.kernel_name) + # print(fused_attention_trimmed.attn_fwd) + if False: + mod = importlib.import_module(arg_path.stem) + print(mod.attn_fwd) + # print(fused_attention_trimmed.attn_fwd) + kernel = globals()[f"{arg_path.stem}.{args.kernel_name}"] + print(f"{kernel=}") + + grid = args.grid.split(",") + assert len(grid) == 3 + + # validate and parse signature + signature = list(map(lambda s: s.strip(" "), args.signature.split(","))) + + def hash_signature(signature: List[str]): + m = hashlib.sha256() + m.update(" ".join(signature).encode()) + return m.hexdigest()[:8] + + meta_sig = f"warps{args.num_warps}xstages{args.num_stages}" + sig_hash = hash_signature(signature + [meta_sig]) + + def constexpr(s): + try: + ret = int(s) + return ret + except ValueError: + pass + try: + ret = float(s) + return ret + except ValueError: + pass + if s == 'True': + return True + if s == 'False': + return False + return None + + hints = {i: constexpr(s.split(":")[1]) for i, s in enumerate(signature) if ":" in s} + hints = {k: v for k, v in hints.items() if v is not None} + constexprs = {i: constexpr(s) for i, s in enumerate(signature)} + constexprs = {k: v for k, v in constexprs.items() if v is not None} + # print(f"{constexprs=}") + signature = {i: s.split(":")[0] for i, s in enumerate(signature) if i not in constexprs} + const_sig = 'x'.join([str(v) for v in constexprs.values()]) + doc_string = [f"{kernel.arg_names[i]}={constexprs[i]}" for i in constexprs.keys()] + doc_string += [f"num_warps={args.num_warps}", f"num_stages={args.num_stages}"] + + # compile ast into cubin + for h in hints.values(): + assert h in [1, 16], f"Only 1 and 16 are valid hints, got {h}" + divisible_by_16 = [i for i, h in hints.items() if h == 16] + equal_to_1 = [i for i, h in hints.items() if h == 1] + config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) + for i in equal_to_1: + constexprs.update({i: 1}) + # print(f'{kernel=}') + ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages, waves_per_eu=args.waves_per_eu, aot_arch=args.target) + hsaco_path = ccinfo.asm.get('hsaco_path', None) + if args.verbose: + print(dir(ccinfo)) + print(f'{ccinfo.asm.keys()=}') + print(f'{ccinfo.fn=}') + print(f'{hsaco_path=}') + + if hsaco_path is not None: + if args.nostrip: + shutil.copy(hsaco_path, out_path.with_suffix('.hsaco')) + else: + subprocess.run(['/opt/rocm/llvm/bin/llvm-objcopy', '--remove-section', '.debug_*', str(hsaco_path), str(out_path.with_suffix('.hsaco'))]) + + with out_path.with_suffix('.json').open("w") as fp: + json.dump(ccinfo.metadata, fp, indent=2) + +if __name__ == "__main__": + main() diff --git a/v2python/generate_compile.py b/v2python/generate_compile.py index 800e4b9..8117e5d 100644 --- a/v2python/generate_compile.py +++ b/v2python/generate_compile.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from .rules import kernels as triton_kernels from .tuning_database import KernelTuningDatabase import io diff --git a/v2python/generate_shim.py b/v2python/generate_shim.py index 8fb1a48..8ffebb3 100755 --- a/v2python/generate_shim.py +++ b/v2python/generate_shim.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from .rules import kernels as triton_kernels from .tuning_database import KernelTuningDatabase import io diff --git a/v2python/gpu_targets.py b/v2python/gpu_targets.py index 36e0668..86ef5dc 100644 --- a/v2python/gpu_targets.py +++ b/v2python/gpu_targets.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + AOTRITON_SUPPORTED_GPUS = { 'MI200' : 'GPU_ARCH_AMD_GFX90A', diff --git a/v2python/kernel_argument.py b/v2python/kernel_argument.py index 487f88d..f25047b 100644 --- a/v2python/kernel_argument.py +++ b/v2python/kernel_argument.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + import numpy as np from enum import Enum from .object_desc import ObjectFileDescription diff --git a/v2python/kernel_desc.py b/v2python/kernel_desc.py index ca04696..3e8a7b2 100644 --- a/v2python/kernel_desc.py +++ b/v2python/kernel_desc.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + import itertools from collections import defaultdict import io diff --git a/v2python/kernel_signature.py b/v2python/kernel_signature.py index 4b62546..10dae4d 100644 --- a/v2python/kernel_signature.py +++ b/v2python/kernel_signature.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from .gpu_targets import AOTRITON_GPU_ARCH_TUNING_STRING class KernelSignature(object): diff --git a/v2python/object_desc.py b/v2python/object_desc.py index 2d3b128..b41af8c 100644 --- a/v2python/object_desc.py +++ b/v2python/object_desc.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT from pathlib import Path import json diff --git a/v2python/rules/__init__.py b/v2python/rules/__init__.py index 3cd4234..f5a54a9 100644 --- a/v2python/rules/__init__.py +++ b/v2python/rules/__init__.py @@ -1 +1,4 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from .flash import kernels diff --git a/v2python/rules/flash/__init__.py b/v2python/rules/flash/__init__.py index 0714985..e9ea4b2 100644 --- a/v2python/rules/flash/__init__.py +++ b/v2python/rules/flash/__init__.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from .attn_fwd import attn_fwd from .bwd_preprocess import bwd_preprocess from .bwd_kernel_dk_dv import bwd_kernel_dk_dv diff --git a/v2python/rules/flash/_common.py b/v2python/rules/flash/_common.py index 7fd4306..b024dc6 100644 --- a/v2python/rules/flash/_common.py +++ b/v2python/rules/flash/_common.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from ...kernel_desc import KernelDescription, get_possible_types, select_pattern from ...autotune_binning import BinningLessOrEqual, BinningExact diff --git a/v2python/rules/flash/attn_fwd.py b/v2python/rules/flash/attn_fwd.py index cde6997..7526ac7 100644 --- a/v2python/rules/flash/attn_fwd.py +++ b/v2python/rules/flash/attn_fwd.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from ._common import FlashKernel, select_pattern, BinningLessOrEqual, BinningExact class attn_fwd(FlashKernel): diff --git a/v2python/rules/flash/bwd_kernel_dk_dv.py b/v2python/rules/flash/bwd_kernel_dk_dv.py index c78e87f..011cb32 100644 --- a/v2python/rules/flash/bwd_kernel_dk_dv.py +++ b/v2python/rules/flash/bwd_kernel_dk_dv.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from ._common import FlashKernel, get_possible_types, select_pattern, BinningLessOrEqual, BinningExact from .attn_fwd import attn_fwd diff --git a/v2python/rules/flash/bwd_kernel_dq.py b/v2python/rules/flash/bwd_kernel_dq.py index 67356fd..90895d2 100644 --- a/v2python/rules/flash/bwd_kernel_dq.py +++ b/v2python/rules/flash/bwd_kernel_dq.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from ._common import FlashKernel, get_possible_types, select_pattern, BinningLessOrEqual, BinningExact from .attn_fwd import attn_fwd from .bwd_kernel_dk_dv import bwd_kernel_dk_dv diff --git a/v2python/rules/flash/bwd_preprocess.py b/v2python/rules/flash/bwd_preprocess.py index f330bc2..22296c5 100644 --- a/v2python/rules/flash/bwd_preprocess.py +++ b/v2python/rules/flash/bwd_preprocess.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from ._common import FlashKernel, get_possible_types, select_pattern, BinningLessOrEqual, BinningExact from .attn_fwd import attn_fwd diff --git a/v2python/tuning_database.py b/v2python/tuning_database.py index 3f13f65..83059f5 100644 --- a/v2python/tuning_database.py +++ b/v2python/tuning_database.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + import json import pathlib from copy import deepcopy diff --git a/v2python/tuning_lut.py b/v2python/tuning_lut.py index e658789..798e787 100644 --- a/v2python/tuning_lut.py +++ b/v2python/tuning_lut.py @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + from .kernel_signature import KernelSignature from .kernel_desc import get_template import numpy as np diff --git a/v2src/CMakeLists.txt b/v2src/CMakeLists.txt index 48d6120..cd941b6 100644 --- a/v2src/CMakeLists.txt +++ b/v2src/CMakeLists.txt @@ -1,3 +1,6 @@ +# Copyright © 2023-2024 Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + message("CMAKE_SOURCE_DIR ${CMAKE_SOURCE_DIR}") message("CMAKE_CURRENT_LIST_DIR ${CMAKE_CURRENT_LIST_DIR}") message("CMAKE_CURRENT_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}") diff --git a/v2src/flash/attn_bwd.cc b/v2src/flash/attn_bwd.cc index f98a870..8ec8f2e 100644 --- a/v2src/flash/attn_bwd.cc +++ b/v2src/flash/attn_bwd.cc @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #include #include #include diff --git a/v2src/flash/attn_check.cc b/v2src/flash/attn_check.cc index 47da111..7aab43c 100644 --- a/v2src/flash/attn_check.cc +++ b/v2src/flash/attn_check.cc @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #include #include #include diff --git a/v2src/flash/attn_fwd.cc b/v2src/flash/attn_fwd.cc index 345fe14..d49b1b9 100644 --- a/v2src/flash/attn_fwd.cc +++ b/v2src/flash/attn_fwd.cc @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #include #include #include diff --git a/v2src/template/autotune_table_entry.cc b/v2src/template/autotune_table_entry.cc index 0eade10..ed58c46 100644 --- a/v2src/template/autotune_table_entry.cc +++ b/v2src/template/autotune_table_entry.cc @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + // clang-format off #define INCBIN_PREFIX g_aotriton_FAMILY_[[kernel_family_name]]_KERNEL_[[shim_kernel_name]]_GPU_[[gpu]]_ #define INCBIN_STYLE INCBIN_STYLE_SNAKE diff --git a/v2src/template/shim.cc b/v2src/template/shim.cc index f31c0d1..476d039 100644 --- a/v2src/template/shim.cc +++ b/v2src/template/shim.cc @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + // clang-format off #include "shim.[[shim_kernel_name]].h" #include diff --git a/v2src/template/shim.h b/v2src/template/shim.h index 2125720..f8ecd02 100644 --- a/v2src/template/shim.h +++ b/v2src/template/shim.h @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + // clang-format off #pragma once #include diff --git a/v2src/triton_kernel.cc b/v2src/triton_kernel.cc index 47dbe59..4e06bc8 100644 --- a/v2src/triton_kernel.cc +++ b/v2src/triton_kernel.cc @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #include #include #include diff --git a/v2src/util.cc b/v2src/util.cc index f482241..28fca9d 100644 --- a/v2src/util.cc +++ b/v2src/util.cc @@ -1,3 +1,6 @@ +// Copyright © 2023-2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT + #include #include #include