193 changes: 193 additions & 0 deletions transformer_engine/musa/__init__.py
@@ -0,0 +1,193 @@
import sys
import torch
import torch.utils
import torch.utils.data
import torch_musa


def patch_before_import_te():
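    # Importing these MUSA-adapted submodules up front lets their patches
    # register before the stock transformer_engine modules are loaded.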
from .pytorch import attention
from .pytorch import tensor
from .pytorch import fp8
from .pytorch import distributed
from .pytorch.module import base
from .pytorch.ops import op
from .pytorch.cpp_extensions import cast
from .pytorch.module import linear
from .pytorch.module import grouped_linear
from .pytorch import utils

def patch_after_import_torch():
def hook_cuda_device(device):
if isinstance(device, str) and device.startswith("cuda"):
return device.replace("cuda", "musa")
        if isinstance(device, torch.device) and device.type == "cuda":
            # torch.device("musa", None) raises, so drop a missing index.
            if device.index is None:
                return torch.device("musa")
            return torch.device("musa", device.index)
return device

def maybe_hook_cuda_args(args, kwargs):
        new_args = [hook_cuda_device(arg) for arg in args]
        if "device" in kwargs:
            kwargs["device"] = hook_cuda_device(kwargs["device"])
return tuple(new_args), kwargs
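    # e.g. hook_cuda_device("cuda:1") -> "musa:1";
    #      hook_cuda_device(torch.device("cuda", 0)) -> torch.device("musa", 0)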

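    # Alias the torch.cuda API surface to the torch_musa equivalents.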
torch.cuda.is_available = torch.musa.is_available
torch.cuda.current_device = torch.musa.current_device
torch.cuda.device_count = torch.musa.device_count
torch.cuda.set_device = torch.musa.set_device
torch.cuda.DoubleTensor = torch.musa.DoubleTensor
torch.cuda.FloatTensor = torch.musa.FloatTensor
torch.cuda.LongTensor = torch.musa.LongTensor
torch.cuda.HalfTensor = torch.musa.HalfTensor
torch.cuda.BFloat16Tensor = torch.musa.BFloat16Tensor
torch.cuda.IntTensor = torch.musa.IntTensor
torch.cuda.synchronize = torch.musa.synchronize
torch.cuda.get_rng_state = torch.musa.get_rng_state
torch.cuda.set_rng_state = torch.musa.set_rng_state
torch.cuda.empty_cache = torch.musa.empty_cache
torch.Tensor.cuda = torch.Tensor.musa
torch.cuda.manual_seed = torch.musa.manual_seed
torch.cuda.Event = torch.musa.Event
torch.cuda.Stream = torch.musa.Stream
torch.cuda.current_stream = torch.musa.current_stream
torch.cuda.set_stream = torch.musa.set_stream
torch.cuda.get_device_properties = torch.musa.get_device_properties
    # Call torch.musa.current_device() once so torch.musa.default_generators
    # is populated before being aliased below.
    _ = torch.musa.current_device()
torch.cuda.default_generators = torch.musa.default_generators

torch.cuda.memory_allocated = torch.musa.memory_allocated
torch.cuda.max_memory_allocated = torch.musa.max_memory_allocated
torch.cuda.memory_reserved = torch.musa.memory_reserved
torch.cuda.max_memory_reserved = torch.musa.max_memory_reserved

    # (yehua.zhang) Replace _lazy_call to avoid a CPU memory leak: when CUDA
    # init fails inside _lazy_call, queued callbacks get emplaced back endlessly.
torch.cuda._lazy_call = torch.musa.core._lazy_init._lazy_call
torch.cuda._lazy_init = torch.musa.core._lazy_init._lazy_init

    # Wrap the common tensor factories so a device="cuda..." argument is
    # transparently remapped to the corresponding MUSA device.
    def make_device_patched(fn):
        def patched(*args, **kwargs):
            args, kwargs = maybe_hook_cuda_args(args, kwargs)
            return fn(*args, **kwargs)
        return patched

    for name in ("tensor", "zeros", "ones", "empty", "rand", "arange", "empty_like"):
        setattr(torch, name, make_device_patched(getattr(torch, name)))

    # Keep Tensor.type() reporting "cuda" in its string form so callers that
    # string-match on the device name continue to work.
    orig_type = torch.Tensor.type
    def musa_type(*args, **kwargs):
        result = orig_type(*args, **kwargs)
        if isinstance(result, str):
            result = result.replace("musa", "cuda")
        return result
    torch.Tensor.type = musa_type

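    # Report is_cuda=True unconditionally so downstream CUDA checks accept
    # MUSA tensors (note this also applies to CPU tensors).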
original_is_cuda = torch.Tensor.is_cuda
def always_cuda(self):
return True
torch.Tensor.is_cuda = property(always_cuda)

origin_init_process_group = torch.distributed.init_process_group
def patched_init_process_group(*args, **kwargs):
        if kwargs.get("backend") == "nccl":
            kwargs["backend"] = "mccl"
result = origin_init_process_group(*args, **kwargs)
return result
torch.distributed.init_process_group = patched_init_process_group

# def pin_memory(data, device=None):
# return data
# torch.utils.data._utils.pin_memory.pin_memory = pin_memory

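    # No-op the NVTX profiling hooks; no MUSA equivalent is wired up here.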
    def _pass_nvtx(*args, **kwargs):
        return
    torch.cuda.nvtx.range_push = _pass_nvtx
    torch.cuda.nvtx.range_pop = _pass_nvtx

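    # CUDA-graph stream capture is unsupported; always report "not capturing".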
torch.cuda.is_current_stream_capturing = lambda: False

origin_module_to = torch.nn.Module.to
def patched_module_to(self, *args, **kwargs):
args, kwargs = maybe_hook_cuda_args(args, kwargs)
return origin_module_to(self, *args, **kwargs)
torch.nn.Module.to = patched_module_to

origin_tensor_to = torch.Tensor.to
def patched_tensor_to(self, *args, **kwargs):
args, kwargs = maybe_hook_cuda_args(args, kwargs)
return origin_tensor_to(self, *args, **kwargs)
torch.Tensor.to = patched_tensor_to

def get_default_device():
device = torch.device("musa", torch.musa.current_device())
return device
torch.get_default_device = get_default_device

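    # Report autocast as disabled; MUSA autocast is not wired up by this patch.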
def is_autocast_enabled(device_type=None):
return False
torch.is_autocast_enabled = is_autocast_enabled

    import os
    # HACK(sherry): enable torch.compile.
    os.environ["NVTE_TORCH_COMPILE"] = "1"
    os.environ["TORCHDYNAMO_DISABLE"] = "0"

def py_patch():
    # math.lcm is only available from Python 3.9; add a fallback for older runtimes.
    if sys.version_info >= (3, 9):
        return
    import math

    def lcm(a, b):
        return abs(a * b) // math.gcd(a, b)

    math.lcm = lcm


py_patch()
patch_before_import_te()
patch_after_import_torch()
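
As a sanity check, these patches only take effect if this module is imported before anything touches torch.cuda. A minimal smoke-test sketch (the te.Linear usage is an illustrative assumption, not part of this diff):

import torch
import transformer_engine.musa          # applies the CUDA -> MUSA patches on import
import transformer_engine.pytorch as te

x = torch.rand(16, 32, device="cuda")   # remapped to the current MUSA device
layer = te.Linear(32, 32).to("cuda")    # Module.to() is patched the same way
y = layer(x)
print(y.device)                         # a musa device; y.type() still says "cuda"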
195 changes: 195 additions & 0 deletions transformer_engine/musa/common/CMakeLists.txt
@@ -0,0 +1,195 @@
cmake_minimum_required(VERSION 3.21)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(MUSA_DIR "/usr/local/musa")
set(MUSA_ARCH "31")

# Transformer Engine library
project(transformer_engine LANGUAGES CXX)

list(APPEND CMAKE_MODULE_PATH "${MUSA_DIR}/cmake")
list(APPEND CMAKE_MODULE_PATH "${MUSA_DIR}/lib/cmake/mudnn")

find_package(MUSA REQUIRED)
string(APPEND MUSA_MCC_FLAGS " -std=c++${CMAKE_CXX_STANDARD}")
string(APPEND MUSA_MCC_FLAGS " --offload-arch=mp_${MUSA_ARCH}")
# -mllvm -mtgpu-tempint-prealloc=1 only works for MUSA_ARCH=31
if (MUSA_ARCH STREQUAL "31")
string(APPEND MUSA_MCC_FLAGS " -mllvm -mtgpu-tempint-prealloc=1")
endif()
set(MUSA_VERBOSE_BUILD ON)
set(MUSA_LINK_LIBRARIES_KEYWORD PUBLIC)

if (CMAKE_BUILD_TYPE STREQUAL "Debug")
string(APPEND MUSA_MCC_FLAGS " -g")
endif()

set(DEPENDENT_TARGETS)

find_package(MUSAToolkit REQUIRED)
list(APPEND DEPENDENT_TARGETS MUSA::toolkit)

include(mudnnTargets)
list(APPEND DEPENDENT_TARGETS mudnn)

find_package(MCCL REQUIRED)
add_library(MUSA::mccl SHARED IMPORTED)
set_target_properties(MUSA::mccl PROPERTIES
IMPORTED_LOCATION ${MCCL_LIBRARIES}
INTERFACE_INCLUDE_DIRECTORIES ${MCCL_INCLUDE_DIRS}
)
list(APPEND DEPENDENT_TARGETS MUSA::mccl)

find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
list(APPEND DEPENDENT_TARGETS Python::Module)

execute_process(
  COMMAND
    ${Python_EXECUTABLE} -c "import os, torch_musa;print(os.path.dirname(torch_musa.__file__))"
  ERROR_QUIET
  OUTPUT_VARIABLE TORCH_MUSA_PYTHONPATH
  OUTPUT_STRIP_TRAILING_WHITESPACE
)

add_library(torch_musa_python SHARED IMPORTED)
set_target_properties(torch_musa_python PROPERTIES
IMPORTED_LOCATION "${TORCH_MUSA_PYTHONPATH}/lib/libmusa_python.so"
)
set_property(TARGET torch_musa_python APPEND PROPERTY
INTERFACE_INCLUDE_DIRECTORIES "${TORCH_MUSA_PYTHONPATH}/.."
)
set_property(TARGET torch_musa_python APPEND PROPERTY
INTERFACE_INCLUDE_DIRECTORIES "${TORCH_MUSA_PYTHONPATH}/share/torch_musa_codegen"
)
set_property(TARGET torch_musa_python APPEND PROPERTY
INTERFACE_INCLUDE_DIRECTORIES "${TORCH_MUSA_PYTHONPATH}/share/generated_cuda_compatible/include"
)
set_property(TARGET torch_musa_python APPEND PROPERTY
INTERFACE_INCLUDE_DIRECTORIES "${TORCH_MUSA_PYTHONPATH}/share/generated_cuda_compatible/include/torch/csrc/api/include"
)
list(APPEND DEPENDENT_TARGETS torch_musa_python)

execute_process(
  COMMAND
    ${Python_EXECUTABLE} -c "import os, torch;print(os.path.dirname(torch.__file__))"
  ERROR_QUIET
  OUTPUT_VARIABLE TORCH_PYTHONPATH
  OUTPUT_STRIP_TRAILING_WHITESPACE
)

add_library(torch_python SHARED IMPORTED)
set_target_properties(torch_python PROPERTIES
IMPORTED_LOCATION "${TORCH_PYTHONPATH}/lib/libtorch_python.so"
)
list(APPEND DEPENDENT_TARGETS torch_python)

# Configure Transformer Engine library
set(transformer_engine_SOURCES)
set(PLUGIN_NAME "transformer_engine")
list(APPEND transformer_engine_SOURCES
common.mu
transformer_engine.cpp
activation/gelu.mu
activation/relu.mu
activation/swiglu.mu
comm_gemm_overlap/comm_gemm_overlap.cpp
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.mu
fused_attn/fused_attn.cpp
fused_attn/thd_utils.mu
fused_rope/fused_rope.mu
fused_softmax/scaled_aligned_causal_masked_softmax.mu
fused_softmax/scaled_masked_softmax.mu
fused_softmax/scaled_upper_triang_masked_softmax.mu
gemm/mudnn_gemm.cpp
permutation/permutation.mu
permutation/permutation_mask.mu
recipe/delayed_scaling.mu
swizzle/swizzle.mu
transpose/multi_cast_transpose.mu
transpose/cast_transpose_fusion.mu
transpose/transpose_fusion.mu
transpose/transpose.mu
transpose/cast_transpose.mu
util/cast.mu
util/musa_driver.cpp
util/musa_runtime.cpp
util/padding.mu
util/rtc.cpp
util/system.cpp
util/mtfp8_cast_transpose.mu
util/mtfp8_dequantize.mu
)
set_source_files_properties(${transformer_engine_SOURCES}
PROPERTIES
MUSA_SOURCE_PROPERTY_FORMAT OBJ
)

musa_add_library(${PLUGIN_NAME} SHARED ${transformer_engine_SOURCES})
target_include_directories(${PLUGIN_NAME} PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/.."
"${CMAKE_CURRENT_SOURCE_DIR}/include"
)
target_link_libraries(${PLUGIN_NAME} PUBLIC ${DEPENDENT_TARGETS})

# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
if (NVTE_UB_WITH_MPI)
find_package(MPI REQUIRED)
target_link_libraries(${PLUGIN_NAME} PUBLIC MPI::MPI_CXX)
target_compile_definitions(${PLUGIN_NAME} PUBLIC NVTE_UB_WITH_MPI)
endif()

# Helper functions to make header files with C++ strings
function(make_string_header STRING STRING_NAME)
configure_file(
"util/string_header.h.in"
"string_headers/${STRING_NAME}.h"
@ONLY
)
endfunction()
function(make_string_header_from_file file_ STRING_NAME)
file(READ "${file_}" STRING)
configure_file(
util/string_header.h.in
"string_headers/${STRING_NAME}.h"
@ONLY
)
endfunction()
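# Both helpers rely on configure_file() substituting @STRING@ and @STRING_NAME@
# from the enclosing variables into util/string_header.h.in (presumably a
# template declaring a C++ string constant named ${STRING_NAME} holding ${STRING}).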

# Header files with C++ strings
make_string_header(
"${MUSA_DIR}/include"
string_path_musa_include
)
make_string_header_from_file(
transpose/rtc/cast_transpose_fusion.mu
string_code_transpose_rtc_cast_transpose_fusion_mu
)
make_string_header_from_file(
transpose/rtc/cast_transpose.mu
string_code_transpose_rtc_cast_transpose_mu
)
make_string_header_from_file(
transpose/rtc/transpose.mu
string_code_transpose_rtc_transpose_mu
)
make_string_header_from_file(
utils.muh
string_code_utils_muh
)
make_string_header_from_file(
util/math.h
string_code_util_math_h
)
target_include_directories(${PLUGIN_NAME} PRIVATE
"${CMAKE_CURRENT_BINARY_DIR}/string_headers"
)

set_target_properties(${PLUGIN_NAME} PROPERTIES INSTALL_RPATH_USE_LINK_PATH ON)

# Install library
install(TARGETS ${PLUGIN_NAME} DESTINATION .)
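
For reference, a typical out-of-source configure and build might look like the following (the build directory and install prefix are illustrative; MUSA_DIR and MUSA_ARCH are hard-coded above, and -DNVTE_UB_WITH_MPI=ON is only needed for MPI-bootstrapped Userbuffers):

cmake -S transformer_engine/musa/common -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build -j
cmake --install build --prefix <site-packages>/transformer_engine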