diff --git a/python/llaisys/libllaisys/__init__.py b/python/llaisys/libllaisys/__init__.py index f536fb52..1ab673e7 100644 --- a/python/llaisys/libllaisys/__init__.py +++ b/python/llaisys/libllaisys/__init__.py @@ -12,7 +12,7 @@ from .tensor import llaisysTensor_t from .tensor import load_tensor from .ops import load_ops - +from .models import load_models, LlaisysQwen2Meta, LlaisysQwen2Weights def load_shared_library(): lib_dir = Path(__file__).parent @@ -38,6 +38,7 @@ def load_shared_library(): load_runtime(LIB_LLAISYS) load_tensor(LIB_LLAISYS) load_ops(LIB_LLAISYS) +load_models(LIB_LLAISYS) __all__ = [ @@ -52,4 +53,5 @@ def load_shared_library(): "llaisysMemcpyKind_t", "MemcpyKind", "llaisysStream_t", + "LlaisysQwen2Meta", "LlaisysQwen2Weights", ] diff --git a/python/llaisys/libllaisys/models.py b/python/llaisys/libllaisys/models.py new file mode 100644 index 00000000..22a6602f --- /dev/null +++ b/python/llaisys/libllaisys/models.py @@ -0,0 +1,72 @@ +from ctypes import Structure, POINTER, c_size_t, c_float, c_int, c_int64 +from .llaisys_types import llaisysDataType_t, llaisysDeviceType_t +from .tensor import llaisysTensor_t + + +class LlaisysQwen2Meta(Structure): + _fields_ = [ + ("dtype", llaisysDataType_t), + ("nlayer", c_size_t), + ("hs", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("maxseq", c_size_t), + ("voc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_int64), + ] + + +class LlaisysQwen2Weights(Structure): + _fields_ = [ + ("in_embed", llaisysTensor_t), + ("out_embed", llaisysTensor_t), + ("out_norm_w", llaisysTensor_t), + ("attn_norm_w", POINTER(llaisysTensor_t)), + ("attn_q_w", POINTER(llaisysTensor_t)), + ("attn_q_b", POINTER(llaisysTensor_t)), + ("attn_k_w", POINTER(llaisysTensor_t)), + ("attn_k_b", POINTER(llaisysTensor_t)), + ("attn_v_w", POINTER(llaisysTensor_t)), + ("attn_v_b", POINTER(llaisysTensor_t)), + ("attn_o_w", POINTER(llaisysTensor_t)), + ("mlp_norm_w", 
POINTER(llaisysTensor_t)), + ("mlp_gate_w", POINTER(llaisysTensor_t)), + ("mlp_up_w", POINTER(llaisysTensor_t)), + ("mlp_down_w", POINTER(llaisysTensor_t)), + ] + +# Opaque handle +class LlaisysQwen2Model(Structure): + pass # 定义一个空结构体用于类型占位 + +LlaisysQwen2Model_p = POINTER(LlaisysQwen2Model) + +def load_models(lib): + # llaisysQwen2ModelCreate + lib.llaisysQwen2ModelCreate.argtypes = [ + POINTER(LlaisysQwen2Meta), + llaisysDeviceType_t, + POINTER(c_int), # device_ids + c_int # ndevice + ] + lib.llaisysQwen2ModelCreate.restype = LlaisysQwen2Model_p + + # llaisysQwen2ModelDestroy + lib.llaisysQwen2ModelDestroy.argtypes = [LlaisysQwen2Model_p] + lib.llaisysQwen2ModelDestroy.restype = None + + # llaisysQwen2ModelWeights + lib.llaisysQwen2ModelWeights.argtypes = [LlaisysQwen2Model_p] + lib.llaisysQwen2ModelWeights.restype = POINTER(LlaisysQwen2Weights) + + # llaisysQwen2ModelInfer + lib.llaisysQwen2ModelInfer.argtypes = [ + LlaisysQwen2Model_p, + POINTER(c_int64), # token_ids + c_size_t # ntoken + ] + lib.llaisysQwen2ModelInfer.restype = c_int64 \ No newline at end of file diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b2..3ba0000b 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -1,33 +1,172 @@ from typing import Sequence from ..libllaisys import LIB_LLAISYS from ..libllaisys import DeviceType - +from ..libllaisys.models import LlaisysQwen2Meta, LlaisysQwen2Weights +from ..tensor import Tensor +import ctypes +from ctypes import POINTER, c_int, c_int64, c_float, c_size_t from pathlib import Path -import safetensors - +import json +import numpy as np +import mmap +import struct class Qwen2: def __init__(self, model_path, device: DeviceType = DeviceType.CPU): - # TODO: Implement model constructor - model_path = Path(model_path) + + # 1. Load Config + with open(model_path / "config.json", "r") as f: + config = json.load(f) + + # 2. 
Prepare Meta + self.meta = LlaisysQwen2Meta() + self.meta.dtype = 19 # BF16 + self.meta.nlayer = config["num_hidden_layers"] + self.meta.hs = config["hidden_size"] + self.meta.nh = config["num_attention_heads"] + self.meta.nkvh = config["num_key_value_heads"] + self.meta.dh = self.meta.hs // self.meta.nh + self.meta.di = config["intermediate_size"] + self.meta.maxseq = 2048 + self.meta.voc = config["vocab_size"] + self.meta.epsilon = config["rms_norm_eps"] + self.meta.theta = config.get("rope_theta", 10000.0) + self.meta.end_token = 151643 # <|end_of_text|> + + # 3. Create C Model + device_ids = (c_int * 1)(0) + self._model_handle = LIB_LLAISYS.llaisysQwen2ModelCreate( + ctypes.byref(self.meta), + device.value, + device_ids, + 1 + ) + + # 4. Get Weights Structure Pointers + self.weights_ptr = LIB_LLAISYS.llaisysQwen2ModelWeights(self._model_handle).contents + # 5. Load Weights Manually + print("Loading weights...") + self._load_safetensors_manually(model_path) + + def _load_safetensors_manually(self, model_path: Path): + """ + Manually parse safetensors headers and mmap data as uint16. 
+ """ for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") - for name_ in data_.keys(): - ## TODO: load the model weights - pass + with open(file, 'rb') as f: + # Read Header Size + header_len_bytes = f.read(8) + if not header_len_bytes: continue + header_len = struct.unpack('=2.4.0 transformers accelerate + numpy>=1.21.0 [options.package_data] llaisys = libllaisys/*.so libllaisys/*.dll - libllaisys/*.dylib + libllaisys/*.dylib \ No newline at end of file diff --git a/src/core/runtime/runtime.cpp b/src/core/runtime/runtime.cpp index 7f03a862..e867e3a9 100644 --- a/src/core/runtime/runtime.cpp +++ b/src/core/runtime/runtime.cpp @@ -2,6 +2,9 @@ #include "../../device/runtime_api.hpp" #include "../allocator/naive_allocator.hpp" +#ifdef ENABLE_NVIDIA_API +#include "../../device/nvidia/nvidia_resource.cuh" +#endif namespace llaisys::core { Runtime::Runtime(llaisysDeviceType_t device_type, int device_id) @@ -9,6 +12,15 @@ Runtime::Runtime(llaisysDeviceType_t device_type, int device_id) _api = llaisys::device::getRuntimeAPI(_device_type); _stream = _api->create_stream(); _allocator = new allocators::NaiveAllocator(_api); + if (device_type == LLAISYS_DEVICE_NVIDIA) { +#ifdef ENABLE_NVIDIA_API + // 只有在 CUDA 工具链存在的环境下,才会被编译进二进制文件 + _device_resource = new llaisys::device::nvidia::Resource(device_id); +#else + // 在无 CUDA 环境下抛出绝对明确的运行时异常 + throw std::runtime_error("Llaisys Runtime Error: The framework was compiled without NVIDIA backend support. 
No CUDA toolkit detected during build."); +#endif +} } Runtime::~Runtime() { @@ -19,6 +31,10 @@ Runtime::~Runtime() { _allocator = nullptr; _api->destroy_stream(_stream); _api = nullptr; + if (_device_resource) { + delete _device_resource; + _device_resource = nullptr; + } } void Runtime::_activate() { diff --git a/src/core/runtime/runtime.hpp b/src/core/runtime/runtime.hpp index 43235824..8dec4c55 100644 --- a/src/core/runtime/runtime.hpp +++ b/src/core/runtime/runtime.hpp @@ -3,6 +3,7 @@ #include "../../device/runtime_api.hpp" #include "../allocator/allocator.hpp" +#include "../../device/device_resource.hpp" namespace llaisys::core { class Runtime { @@ -11,6 +12,7 @@ class Runtime { int _device_id; const LlaisysRuntimeAPI *_api; MemoryAllocator *_allocator; + llaisys::device::DeviceResource *_device_resource = nullptr; bool _is_active; void _activate(); void _deactivate(); @@ -35,7 +37,7 @@ class Runtime { bool isActive() const; const LlaisysRuntimeAPI *api() const; - + llaisys::device::DeviceResource *deviceResource() const { return _device_resource; } storage_t allocateDeviceStorage(size_t size); ; storage_t allocateHostStorage(size_t size); diff --git a/src/device/device_resource.hpp b/src/device/device_resource.hpp index e9062e51..2a337ee0 100644 --- a/src/device/device_resource.hpp +++ b/src/device/device_resource.hpp @@ -14,7 +14,7 @@ class DeviceResource { : _device_type(device_type), _device_id(device_id) { } - ~DeviceResource() = default; + virtual ~DeviceResource() = default; llaisysDeviceType_t getDeviceType() const { return _device_type; } int getDeviceId() const { return _device_id; }; diff --git a/src/device/nvidia/nvidia_resource.cu b/src/device/nvidia/nvidia_resource.cu index 2e63647e..138189c6 100644 --- a/src/device/nvidia/nvidia_resource.cu +++ b/src/device/nvidia/nvidia_resource.cu @@ -1,7 +1,29 @@ #include "nvidia_resource.cuh" +#include +#include +#include namespace llaisys::device::nvidia { -Resource::Resource(int device_id) : 
llaisys::device::DeviceResource(LLAISYS_DEVICE_NVIDIA, device_id) {} +Resource::Resource(int device_id) : llaisys::device::DeviceResource(LLAISYS_DEVICE_NVIDIA, device_id) { + cudaError_t err=cudaSetDevice(device_id); + if (err != cudaSuccess) { + throw std::runtime_error("Failed to set CUDA device in Resource constructor"); + } + // Create cuBLAS handle + cublasStatus_t status=cublasCreate(&_cublas_handle); + if (status != CUBLAS_STATUS_SUCCESS) { + throw std::runtime_error("Failed to create cuBLAS handle in Resource constructor"); + } + + +} +Resource::~Resource() { + // Destroy cuBLAS handle + if (_cublas_handle) { + cublasDestroy(_cublas_handle); + _cublas_handle = nullptr; + } +} } // namespace llaisys::device::nvidia diff --git a/src/device/nvidia/nvidia_resource.cuh b/src/device/nvidia/nvidia_resource.cuh index a3002170..99332366 100644 --- a/src/device/nvidia/nvidia_resource.cuh +++ b/src/device/nvidia/nvidia_resource.cuh @@ -1,11 +1,16 @@ #pragma once #include "../device_resource.hpp" - +#include "cublas_v2.h" namespace llaisys::device::nvidia { class Resource : public llaisys::device::DeviceResource { public: Resource(int device_id); ~Resource(); + + cublasHandle_t cublasHandle() const{ return _cublas_handle; } + +private: + cublasHandle_t _cublas_handle=nullptr; }; } // namespace llaisys::device::nvidia diff --git a/src/device/nvidia/nvidia_runtime_api.cu b/src/device/nvidia/nvidia_runtime_api.cu index cab92826..8a879eeb 100644 --- a/src/device/nvidia/nvidia_runtime_api.cu +++ b/src/device/nvidia/nvidia_runtime_api.cu @@ -1,56 +1,102 @@ #include "../runtime_api.hpp" +#include "../../utils/check.hpp" + +#include +#include #include #include +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(err) \ + << " (" << err << ") at " << __FILE__ << ":" << __LINE__ << std::endl; \ + std::abort(); \ + } \ + } while (0) + namespace llaisys::device::nvidia { namespace 
runtime_api { int getDeviceCount() { - TO_BE_IMPLEMENTED(); + int count=0; + if(cudaGetDeviceCount(&count)!=cudaSuccess) + { + return 0; + } + return count; } -void setDevice(int) { - TO_BE_IMPLEMENTED(); +void setDevice(int device_id) { + CUDA_CHECK(cudaSetDevice(device_id)); } void deviceSynchronize() { - TO_BE_IMPLEMENTED(); + CUDA_CHECK(cudaDeviceSynchronize()); } llaisysStream_t createStream() { - TO_BE_IMPLEMENTED(); + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + return static_cast(stream); } void destroyStream(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + CUDA_CHECK(cudaStreamDestroy(static_cast(stream))); } void streamSynchronize(llaisysStream_t stream) { - TO_BE_IMPLEMENTED(); + CUDA_CHECK(cudaStreamSynchronize(static_cast(stream))); } void *mallocDevice(size_t size) { - TO_BE_IMPLEMENTED(); + void *ptr=nullptr; + CUDA_CHECK(cudaMalloc(&ptr, size)); + return ptr; } void freeDevice(void *ptr) { - TO_BE_IMPLEMENTED(); + if(ptr!=nullptr) + { + CUDA_CHECK(cudaFree(ptr)); + } } void *mallocHost(size_t size) { - TO_BE_IMPLEMENTED(); + void *ptr=nullptr; + CUDA_CHECK(cudaMallocHost(&ptr, size)); + return ptr; } void freeHost(void *ptr) { - TO_BE_IMPLEMENTED(); + if(ptr!=nullptr) + { + CUDA_CHECK(cudaFreeHost(ptr)); + } } +static cudaMemcpyKind toCudaMemcpyKind(llaisysMemcpyKind_t kind) { + switch (kind) { + case LLAISYS_MEMCPY_H2H: + return cudaMemcpyHostToHost; + case LLAISYS_MEMCPY_H2D: + return cudaMemcpyHostToDevice; + case LLAISYS_MEMCPY_D2H: + return cudaMemcpyDeviceToHost; + case LLAISYS_MEMCPY_D2D: + return cudaMemcpyDeviceToDevice; + default: + return cudaMemcpyDefault; + } +} void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); + CUDA_CHECK(cudaMemcpy(dst, src, size, toCudaMemcpyKind(kind))); } -void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) { - TO_BE_IMPLEMENTED(); +void memcpyAsync(void *dst, const void *src, size_t size, 
llaisysMemcpyKind_t kind,llaisysStream_t stream) { + cudaStream_t cuda_stream = static_cast(stream); + CUDA_CHECK(cudaMemcpyAsync(dst, src, size, toCudaMemcpyKind(kind), cuda_stream)); } static const LlaisysRuntimeAPI RUNTIME_API = { diff --git a/src/llaisys/models/qwen2.cc b/src/llaisys/models/qwen2.cc new file mode 100644 index 00000000..5ea05f1f --- /dev/null +++ b/src/llaisys/models/qwen2.cc @@ -0,0 +1,70 @@ +#include "llaisys/models/qwen2.h" +#include "../../models/qwen2/qwen2.hpp" +#include "../llaisys_tensor.hpp" + +// Helper +llaisysTensor_t to_c(llaisys::tensor_t t) { return new LlaisysTensor{t}; } + +__C { + +struct LlaisysQwen2Model { + llaisys::models::qwen2::Qwen2Model *model; + struct LlaisysQwen2Weights c_weights; + // Buffers to hold array pointers + std::vector p_attn_norm, p_attn_q_w, p_attn_q_b, p_attn_k_w, p_attn_k_b, p_attn_v_w, p_attn_v_b, p_attn_o_w; + std::vector p_mlp_norm, p_mlp_gate, p_mlp_up, p_mlp_down; +}; + +LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice) { + auto qwen = new llaisys::models::qwen2::Qwen2Model(meta, device, ndevice > 0 ? 
device_ids[0] : 0); + auto wrapper = new LlaisysQwen2Model; + wrapper->model = qwen; + + wrapper->c_weights.in_embed = to_c(qwen->in_embed); + wrapper->c_weights.out_embed = to_c(qwen->out_embed); + wrapper->c_weights.out_norm_w = to_c(qwen->out_norm_w); + + auto fill = [&](const std::vector& src, std::vector& buf, llaisysTensor_t*& dst_ptr) { + buf.resize(src.size()); + for(size_t i=0; iattn_norm_w, wrapper->p_attn_norm, wrapper->c_weights.attn_norm_w); + fill(qwen->attn_q_w, wrapper->p_attn_q_w, wrapper->c_weights.attn_q_w); + fill(qwen->attn_q_b, wrapper->p_attn_q_b, wrapper->c_weights.attn_q_b); + fill(qwen->attn_k_w, wrapper->p_attn_k_w, wrapper->c_weights.attn_k_w); + fill(qwen->attn_k_b, wrapper->p_attn_k_b, wrapper->c_weights.attn_k_b); + fill(qwen->attn_v_w, wrapper->p_attn_v_w, wrapper->c_weights.attn_v_w); + fill(qwen->attn_v_b, wrapper->p_attn_v_b, wrapper->c_weights.attn_v_b); + fill(qwen->attn_o_w, wrapper->p_attn_o_w, wrapper->c_weights.attn_o_w); + fill(qwen->mlp_norm_w, wrapper->p_mlp_norm, wrapper->c_weights.mlp_norm_w); + fill(qwen->mlp_gate_w, wrapper->p_mlp_gate, wrapper->c_weights.mlp_gate_w); + fill(qwen->mlp_up_w, wrapper->p_mlp_up, wrapper->c_weights.mlp_up_w); + fill(qwen->mlp_down_w, wrapper->p_mlp_down, wrapper->c_weights.mlp_down_w); + + return wrapper; +} + +void llaisysQwen2ModelDestroy(LlaisysQwen2Model * model) { + if (model) { + auto free_t = [](llaisysTensor_t t) { if(t) delete t; }; + free_t(model->c_weights.in_embed); + free_t(model->c_weights.out_embed); + free_t(model->c_weights.out_norm_w); + for(auto t : model->p_attn_norm) free_t(t); + // ... free other vectors ... 
+ delete model->model; + delete model; + } +} + +LlaisysQwen2Weights *llaisysQwen2ModelWeights(LlaisysQwen2Model * model) { + return &model->c_weights; +} + +int64_t llaisysQwen2ModelInfer(LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken) { + return model->model->forward(token_ids, ntoken); +} + +} \ No newline at end of file diff --git a/src/models/qwen2/qwen2.cpp b/src/models/qwen2/qwen2.cpp new file mode 100644 index 00000000..3cca959a --- /dev/null +++ b/src/models/qwen2/qwen2.cpp @@ -0,0 +1,191 @@ +#include "qwen2.hpp" +#include "../../core/context/context.hpp" // 新增:用于访问 memcpy_sync +#include +#include +#include // for memcpy if needed, but we use runtime api + +namespace llaisys::models::qwen2 { + +Qwen2Model::Qwen2Model(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int device_id) { + _meta = *meta; + init_weights(device, device_id); + init_kv_cache(device, device_id); +} + +void Qwen2Model::init_weights(llaisysDeviceType_t device, int device_id) { + auto dtype = _meta.dtype; + + // Helper to create empty tensor + auto mk = [&](const std::vector& shape) { + return Tensor::create(shape, dtype, device, device_id); + }; + + in_embed = mk({_meta.voc, _meta.hs}); + out_embed = mk({_meta.voc, _meta.hs}); + out_norm_w = mk({_meta.hs}); + + auto init_layer_w = [&](std::vector& vec, const std::vector& shape) { + vec.resize(_meta.nlayer); + for(size_t i=0; i<_meta.nlayer; ++i) vec[i] = mk(shape); + }; + + init_layer_w(attn_norm_w, {_meta.hs}); + init_layer_w(attn_q_w, {_meta.nh * _meta.dh, _meta.hs}); + init_layer_w(attn_q_b, {_meta.nh * _meta.dh}); + init_layer_w(attn_k_w, {_meta.nkvh * _meta.dh, _meta.hs}); + init_layer_w(attn_k_b, {_meta.nkvh * _meta.dh}); + init_layer_w(attn_v_w, {_meta.nkvh * _meta.dh, _meta.hs}); + init_layer_w(attn_v_b, {_meta.nkvh * _meta.dh}); + init_layer_w(attn_o_w, {_meta.hs, _meta.nh * _meta.dh}); + + init_layer_w(mlp_norm_w, {_meta.hs}); + init_layer_w(mlp_gate_w, {_meta.di, _meta.hs}); + init_layer_w(mlp_up_w, 
{_meta.di, _meta.hs}); + init_layer_w(mlp_down_w, {_meta.hs, _meta.di}); +} + +void Qwen2Model::init_kv_cache(llaisysDeviceType_t device, int device_id) { + k_cache.resize(_meta.nlayer); + v_cache.resize(_meta.nlayer); + for(size_t i=0; i<_meta.nlayer; ++i) { + // [max_seq, nkvh, dh] + k_cache[i] = Tensor::create({_meta.maxseq, _meta.nkvh, _meta.dh}, _meta.dtype, device, device_id); + v_cache[i] = Tensor::create({_meta.maxseq, _meta.nkvh, _meta.dh}, _meta.dtype, device, device_id); + } +} + +int64_t Qwen2Model::forward(const int64_t *token_ids_ptr, size_t ntoken) { + auto device = in_embed->deviceType(); + int device_id = in_embed->deviceId(); + auto dtype = _meta.dtype; + + // 0. Inputs + auto input_ids_host = Tensor::create({ntoken}, LLAISYS_DTYPE_I64, LLAISYS_DEVICE_CPU, 0); + input_ids_host->load(token_ids_ptr); + + tensor_t input_ids; + if (device == LLAISYS_DEVICE_CPU) { + input_ids = input_ids_host; + } else { + input_ids = Tensor::create({ntoken}, LLAISYS_DTYPE_I64, device, device_id); + llaisys::core::context().runtime().api()->memcpy_sync(input_ids->data(), input_ids_host->data(), ntoken * sizeof(int64_t), LLAISYS_MEMCPY_H2D); + } + + std::vector pos_vec(ntoken); + for(size_t i=0; iload(pos_vec.data()); + + tensor_t pos_ids; + if (device == LLAISYS_DEVICE_CPU) { + pos_ids = pos_ids_host; + } else { + pos_ids = Tensor::create({ntoken}, LLAISYS_DTYPE_I64, device, device_id); + llaisys::core::context().runtime().api()->memcpy_sync(pos_ids->data(), pos_ids_host->data(), ntoken * sizeof(int64_t), LLAISYS_MEMCPY_H2D); + } + + auto mk = [&](const std::vector& s) { return Tensor::create(s, dtype, device, device_id); }; + + // 1. Embedding + auto hidden_states = mk({ntoken, _meta.hs}); + ops::embedding(hidden_states, input_ids, in_embed); + + // 2. 
Layers + for (size_t i = 0; i < _meta.nlayer; ++i) { + auto residual = hidden_states; + + // Attention Block + auto hidden_norm = mk({ntoken, _meta.hs}); + ops::rms_norm(hidden_norm, hidden_states, attn_norm_w[i], _meta.epsilon); + + auto q_flat = mk({ntoken, _meta.nh * _meta.dh}); + ops::linear(q_flat, hidden_norm, attn_q_w[i], attn_q_b[i]); + auto q = q_flat->view({ntoken, _meta.nh, _meta.dh}); + + auto k_flat = mk({ntoken, _meta.nkvh * _meta.dh}); + ops::linear(k_flat, hidden_norm, attn_k_w[i], attn_k_b[i]); + auto k = k_flat->view({ntoken, _meta.nkvh, _meta.dh}); + + auto v_flat = mk({ntoken, _meta.nkvh * _meta.dh}); + ops::linear(v_flat, hidden_norm, attn_v_w[i], attn_v_b[i]); + auto v = v_flat->view({ntoken, _meta.nkvh, _meta.dh}); + + ops::rope(q, q, pos_ids, _meta.theta); + ops::rope(k, k, pos_ids, _meta.theta); + + // KV Cache Update + auto k_slot = k_cache[i]->slice(0, _current_pos, _current_pos + ntoken); + auto v_slot = v_cache[i]->slice(0, _current_pos, _current_pos + ntoken); + + llaisysMemcpyKind_t kind = (device == LLAISYS_DEVICE_CPU) ? 
LLAISYS_MEMCPY_H2H : LLAISYS_MEMCPY_D2D; + size_t copy_bytes = k->numel() * k->elementSize(); + + // 执行拷贝: k -> k_slot + llaisys::core::context().runtime().api()->memcpy_sync( + k_slot->data(), + k->data(), + copy_bytes, + kind + ); + // 执行拷贝: v -> v_slot + llaisys::core::context().runtime().api()->memcpy_sync( + v_slot->data(), + v->data(), + copy_bytes, + kind + ); + + // Attention + auto k_active = k_cache[i]->slice(0, 0, _current_pos + ntoken); + auto v_active = v_cache[i]->slice(0, 0, _current_pos + ntoken); + + auto attn_out_view = mk({ntoken, _meta.nh, _meta.dh}); + float scale = 1.0f / std::sqrt(static_cast(_meta.dh)); + ops::self_attention(attn_out_view, q, k_active, v_active, scale); + + auto attn_out_flat = attn_out_view->view({ntoken, _meta.nh * _meta.dh}); + auto attn_proj = mk({ntoken, _meta.hs}); + ops::linear(attn_proj, attn_out_flat, attn_o_w[i], nullptr); + + // Residual Add + ops::add(hidden_states, hidden_states, attn_proj); + + // MLP Block + auto mlp_norm = mk({ntoken, _meta.hs}); + ops::rms_norm(mlp_norm, hidden_states, mlp_norm_w[i], _meta.epsilon); + + auto gate = mk({ntoken, _meta.di}); + ops::linear(gate, mlp_norm, mlp_gate_w[i], nullptr); + + auto up = mk({ntoken, _meta.di}); + ops::linear(up, mlp_norm, mlp_up_w[i], nullptr); + + auto mlp_act = mk({ntoken, _meta.di}); + ops::swiglu(mlp_act, gate, up); + + auto mlp_out = mk({ntoken, _meta.hs}); + ops::linear(mlp_out, mlp_act, mlp_down_w[i], nullptr); + + ops::add(hidden_states, hidden_states, mlp_out); + } + + // 3. 
Final + auto last_hidden = hidden_states->slice(0, ntoken - 1, ntoken); + auto final_norm = mk({1, _meta.hs}); + ops::rms_norm(final_norm, last_hidden, out_norm_w, _meta.epsilon); + + auto logits = mk({1, _meta.voc}); + ops::linear(logits, final_norm, out_embed, nullptr); + + auto max_idx = Tensor::create({1}, LLAISYS_DTYPE_I64, device, device_id); + auto max_val = mk({1}); + ops::argmax(max_idx, max_val, logits); + + int64_t next_token; + llaisys::core::context().runtime().api()->memcpy_sync(&next_token, max_idx->data(), sizeof(int64_t), LLAISYS_MEMCPY_D2H); + + _current_pos += ntoken; + return next_token; +} + +} \ No newline at end of file diff --git a/src/models/qwen2/qwen2.hpp b/src/models/qwen2/qwen2.hpp new file mode 100644 index 00000000..f1408ee4 --- /dev/null +++ b/src/models/qwen2/qwen2.hpp @@ -0,0 +1,57 @@ +#pragma once +#include "llaisys/models/qwen2.h" +#include "../../tensor/tensor.hpp" +#include "llaisys/ops.h" +#include "../../ops/add/op.hpp" +#include "../../ops/argmax/op.hpp" +#include "../../ops/embedding/op.hpp" +#include "../../ops/linear/op.hpp" +#include "../../ops/rms_norm/op.hpp" +#include "../../ops/rope/op.hpp" +#include "../../ops/self_attention/op.hpp" +#include "../../ops/swiglu/op.hpp" +#include "../../ops/rearrange/op.hpp" +#include + +namespace llaisys::models::qwen2 { + +class Qwen2Model { +public: + LlaisysQwen2Meta _meta; + LlaisysQwen2Weights _weights_ptr; // 用于传递给 Python 的指针结构 + + // 实际存储权重的容器 + tensor_t in_embed; + tensor_t out_embed; + tensor_t out_norm_w; + + // Layers weights (Vector of tensors) + std::vector attn_norm_w; + std::vector attn_q_w; + std::vector attn_q_b; + std::vector attn_k_w; + std::vector attn_k_b; + std::vector attn_v_w; + std::vector attn_v_b; + std::vector attn_o_w; + std::vector mlp_norm_w; + std::vector mlp_gate_w; + std::vector mlp_up_w; + std::vector mlp_down_w; + + // KV Cache + std::vector k_cache; + std::vector v_cache; + size_t _current_pos = 0; + + Qwen2Model(const LlaisysQwen2Meta *meta, 
llaisysDeviceType_t device, int device_id); + ~Qwen2Model() = default; + + int64_t forward(const int64_t *token_ids, size_t ntoken); + +private: + void init_weights(llaisysDeviceType_t device, int device_id); + void init_kv_cache(llaisysDeviceType_t device, int device_id); +}; + +} \ No newline at end of file diff --git a/src/ops/add/nvidia/add_nv.cu b/src/ops/add/nvidia/add_nv.cu new file mode 100644 index 00000000..ab63f0ed --- /dev/null +++ b/src/ops/add/nvidia/add_nv.cu @@ -0,0 +1,57 @@ +#include "../op.hpp" +#include "../../../tensor/tensor.hpp" + +#include +#include +#include + +namespace llaisys::ops::nvidia{ + template + __global__ void add_kernel(T* c,const T* a,const T* b,size_t n) + { + int idx=blockIdx.x*blockDim.x+threadIdx.x; + if(idx +void launch_add(void* c,const void* a,const void* b,size_t n) +{ + T* d_c=reinterpret_cast(c); + const T* d_a=reinterpret_cast(a); + const T* d_b=reinterpret_cast(b); + + int threads=256; + int blocks=(n+threads-1)/threads; + add_kernel<<>>(d_c,d_a,d_b,n); +} + + +void add(tensor_t output,const tensor_t input1,const tensor_t input2) +{ + size_t numel=output->numel(); + auto dtype=output->dtype(); + + void* d_c=output->data(); + const void* d_a=input1->data(); + const void* d_b=input2->data(); + + switch(dtype) + { + case LLAISYS_DTYPE_F32: + launch_add(d_c,d_a,d_b,numel); + break; + case LLAISYS_DTYPE_F16: + launch_add<__half>(d_c,d_a,d_b,numel); + break; + case LLAISYS_DTYPE_BF16: + launch_add<__nv_bfloat16>(d_c,d_a,d_b,numel); + break; + default: + fprintf(stderr, "[Add] Unsupported DataType on CUDA: %d\n", dtype); + abort(); + } +} +} // namespace llaisys::ops::nvidia \ No newline at end of file diff --git a/src/ops/add/op.cpp b/src/ops/add/op.cpp index a057330d..760e4aa5 100644 --- a/src/ops/add/op.cpp +++ b/src/ops/add/op.cpp @@ -5,6 +5,13 @@ #include "cpu/add_cpu.hpp" +namespace llaisys::ops::nvidia { +#ifdef ENABLE_NVIDIA_API + void add(tensor_t output, tensor_t input1, tensor_t input2); +#endif +} + + namespace 
llaisys::ops { void add(tensor_t c, tensor_t a, tensor_t b) { CHECK_SAME_DEVICE(c, a, b); @@ -25,8 +32,7 @@ void add(tensor_t c, tensor_t a, tensor_t b) { return cpu::add(c->data(), a->data(), b->data(), c->dtype(), c->numel()); #ifdef ENABLE_NVIDIA_API case LLAISYS_DEVICE_NVIDIA: - TO_BE_IMPLEMENTED(); - return; + return llaisys::ops::nvidia::add(c, a, b); #endif default: EXCEPTION_UNSUPPORTED_DEVICE; diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp new file mode 100644 index 00000000..f30befcb --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.cpp @@ -0,0 +1,50 @@ +#include "argmax_cpu.hpp" +#include "../../../utils.hpp" + +namespace { + +template +void argmax_(int64_t *max_idx, T *max_val, const T *vals, size_t numel) { + if (numel == 0) return; + + int64_t best_idx = 0; + T best_val_raw = vals[0]; + float best_val_f = llaisys::utils::cast(vals[0]); + + for (size_t i = 1; i < numel; ++i) { + float curr_val_f = llaisys::utils::cast(vals[i]); + if (curr_val_f > best_val_f) { + best_val_f = curr_val_f; + best_idx = i; + best_val_raw = vals[i]; + } + } + + *max_idx = best_idx; + *max_val = best_val_raw; +} + +} // namespace + +namespace llaisys::ops::cpu { + +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t dtype, size_t numel) { + + int64_t *max_idx_ptr = reinterpret_cast(max_idx); + + switch (dtype) { + case LLAISYS_DTYPE_F32: + return argmax_(max_idx_ptr, reinterpret_cast(max_val), + reinterpret_cast(vals), numel); + case LLAISYS_DTYPE_BF16: + return argmax_(max_idx_ptr, reinterpret_cast(max_val), + reinterpret_cast(vals), numel); + case LLAISYS_DTYPE_F16: + return argmax_(max_idx_ptr, reinterpret_cast(max_val), + reinterpret_cast(vals), numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp new file mode 100644 index 
00000000..b659c6bf --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t dtype, size_t numel); + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/argmax/nvidia/argmax_nv.cu b/src/ops/argmax/nvidia/argmax_nv.cu new file mode 100644 index 00000000..9dba3140 --- /dev/null +++ b/src/ops/argmax/nvidia/argmax_nv.cu @@ -0,0 +1,111 @@ +#include "../op.hpp" +#include "../../../tensor/tensor.hpp" + +#include +#include +#include +#include +#include + +namespace llaisys::ops::nvidia { + +template +struct Pair { + T val; + int64_t idx; +}; + +// argmax 核函数:支持同时输出 index 和 value +template +__global__ void argmax_kernel(int64_t* idx_out, T* val_out, const T* input, int n_cols) { + // 1. 确定行号和线程号 + int row = blockIdx.x; + int tid = threadIdx.x; + + // 2. 确定行首地址 + const T* row_ptr = input + row * n_cols; + + // 初始化局部最大值 + T max_val = -1e20f; // 简单粗暴的极小值 + int64_t max_idx = -1; + + // 3. 遍历该行所有元素,计算最大值和索引 (Grid-Stride Loop) + for (int i = tid; i < n_cols; i += blockDim.x) { + T val = row_ptr[i]; + // 强制转 float 比较,确保半精度类型的兼容性 + if ((float)val > (float)max_val) { + max_val = val; + max_idx = i; + } + } + + // 4. 将每个线程的最大值和索引存入共享内存 + extern __shared__ char s_mem[]; + Pair* s_data = reinterpret_cast*>(s_mem); + s_data[tid] = {max_val, max_idx}; + __syncthreads(); + + // 5. 归约计算最终的最大值和索引 + for (int s = blockDim.x / 2; s > 0; s >>= 1) { + if (tid < s) { + // 比较并交换 + if ((float)s_data[tid + s].val > (float)s_data[tid].val) { + s_data[tid] = s_data[tid + s]; + } + } + __syncthreads(); + } + + // 6. 
将结果写入输出 (0号线程负责) + if (tid == 0) { + // 写入索引 + if (idx_out != nullptr) { + idx_out[row] = s_data[0].idx; + } + if (val_out != nullptr) { + val_out[row] = s_data[0].val; + } + } +} + +// Launcher:接收三个 Tensor +template +void launch_argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { + int n_cols = vals->shape().back(); + int n_rows = vals->numel() / n_cols; + + // 获取数据指针 + int64_t* d_idx = (max_idx) ? reinterpret_cast(max_idx->data()) : nullptr; + T* d_val = (max_val) ? reinterpret_cast(max_val->data()) : nullptr; + const T* d_in = reinterpret_cast(vals->data()); + + int threads = 256; + int blocks = n_rows; + + size_t shared_mem_size = threads * sizeof(Pair); + + // 传入两个输出指针 + argmax_kernel<<>>(d_idx, d_val, d_in, n_cols); +} + +void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { + // 根据输入数据的类型进行分发 + auto dtype = vals->dtype(); + + switch (dtype) { + case LLAISYS_DTYPE_F32: + launch_argmax(max_idx, max_val, vals); + break; + case LLAISYS_DTYPE_F16: + launch_argmax<__half>(max_idx, max_val, vals); + break; + case LLAISYS_DTYPE_BF16: + launch_argmax<__nv_bfloat16>(max_idx, max_val, vals); + break; + default: + fprintf(stderr, "[Argmax] Unsupported DataType: %d\n", dtype); + abort(); + } +} + +} // namespace llaisys::ops::nvidia \ No newline at end of file diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp index 6dc37d42..31edfe6c 100644 --- a/src/ops/argmax/op.cpp +++ b/src/ops/argmax/op.cpp @@ -1,7 +1,53 @@ #include "op.hpp" +#include "cpu/argmax_cpu.hpp" +#include "../../core/context/context.hpp" +#include "../../utils.hpp" +#include + +namespace llaisys::ops::nvidia { +#ifdef ENABLE_NVIDIA_API + void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals); +#endif +} namespace llaisys::ops { + void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { - TO_BE_IMPLEMENTED(); + // 基础检查 + CHECK_SAME_DEVICE(max_idx, vals); + + // max_val 是可选的,只有非空时才检查设备 + if (max_val) { + CHECK_SAME_DEVICE(max_val, vals); + //检查连续性 + 
ASSERT(max_val->isContiguous(), "Argmax max_val output must be contiguous"); + } + + ASSERT(max_idx->dtype() == LLAISYS_DTYPE_I64, "Argmax index output must be int32"); + ASSERT(vals->isContiguous() && max_idx->isContiguous(), "Argmax inputs/outputs must be contiguous"); + + // 获取当前上下文应当运行的设备类型 + auto device_type = vals->deviceType(); + + // 切换设备上下文 + llaisys::core::context().setDevice(device_type, vals->deviceId()); + + switch (device_type) { + case LLAISYS_DEVICE_CPU: { + // 防止 max_val 为空时调用 ->data() 导致崩溃 + std::byte* val_ptr = max_val ? max_val->data() : nullptr; + cpu::argmax(max_idx->data(), val_ptr, vals->data(), vals->dtype(), vals->numel()); + return; + } + +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + return llaisys::ops::nvidia::argmax(max_idx, max_val, vals); +#endif + + default: + throw std::runtime_error("Argmax: Unsupported device type"); + } } -} // namespace llaisys::ops + +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp new file mode 100644 index 00000000..37deb551 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.cpp @@ -0,0 +1,42 @@ +#include "embedding_cpu.hpp" +#include "../../../utils.hpp" +#include + +namespace{ + template + void embedding_(T* out,const int64_t *index,const T *weight,size_t num_indices,size_t embedding_dim){ + size_t stride_bytes=embedding_dim *sizeof(T); + + for(size_t i=0;i(index); + + switch (dtype) { + case LLAISYS_DTYPE_F32: + return embedding_(reinterpret_cast(out), index_ptr, + reinterpret_cast(weight), num_indices, embedding_dim); + case LLAISYS_DTYPE_F16: + return embedding_(reinterpret_cast(out), index_ptr, + reinterpret_cast(weight), num_indices, embedding_dim); + case LLAISYS_DTYPE_BF16: + return embedding_(reinterpret_cast(out), index_ptr, + reinterpret_cast(weight), num_indices, embedding_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No 
newline at end of file diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp new file mode 100644 index 00000000..1287fdca --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.hpp @@ -0,0 +1,7 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t dtype, size_t num_indices, size_t embedding_dim); +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/embedding/nvidia/embedding_nv.cu b/src/ops/embedding/nvidia/embedding_nv.cu new file mode 100644 index 00000000..be6be4b7 --- /dev/null +++ b/src/ops/embedding/nvidia/embedding_nv.cu @@ -0,0 +1,91 @@ +#include "../op.hpp" +#include "../../../tensor/tensor.hpp" + +#include +#include +#include +#include // for int64_t +#include + +namespace llaisys::ops::nvidia { + +// ========================================== +// Kernel: Embedding +// ========================================== +template +__global__ void embedding_kernel( + T* output, + const int64_t* indices, + const T* weight, + int embedding_dim, + size_t n +) { + // 1. 计算当前线程负责 Output 中的第几个元素 + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx < n) { + // 2. 计算当前元素对应的 Token 和维度 + size_t token_idx = idx / embedding_dim; + + // dim_idx: 当前处理的是该 Token 向量的第几维 + size_t dim_idx = idx % embedding_dim; + + // 3. 获取真实的查表索引 (Row ID) + int64_t row_id = indices[token_idx]; + + // 4. 计算 Weight 中的源地址 + size_t weight_offset = row_id * embedding_dim + dim_idx; + + // 5. 搬运数据 + output[idx] = weight[weight_offset]; + } +} + +// ========================================== +// Launcher +// ========================================== +template +void launch_embedding(tensor_t output, tensor_t index, tensor_t weight) { + // 1. 
获取维度信息 + size_t num_indices = index->numel(); + int embedding_dim = weight->shape().back(); + + // Output 总元素个数 = Token数 * 向量维度 + size_t total_elements = num_indices * embedding_dim; + + // 2. 准备指针 + T* d_out = reinterpret_cast(output->data()); + const int64_t* d_idx = reinterpret_cast(index->data()); + const T* d_weight = reinterpret_cast(weight->data()); + + // 3. 配置 Kernel + int threads = 256; + // 向上取整计算 Blocks + int blocks = (total_elements + threads - 1) / threads; + + embedding_kernel<<>>(d_out, d_idx, d_weight, embedding_dim, total_elements); +} + +// ========================================== +// 入口函数 +// ========================================== +void embedding(tensor_t output, tensor_t index, tensor_t weight) { + auto dtype = weight->dtype(); + + switch (dtype) { + case LLAISYS_DTYPE_F32: + launch_embedding(output, index, weight); + break; + case LLAISYS_DTYPE_F16: + launch_embedding<__half>(output, index, weight); + break; + case LLAISYS_DTYPE_BF16: + launch_embedding<__nv_bfloat16>(output, index, weight); + break; + default: + fprintf(stderr, "[Embedding NVIDIA] Unsupported DataType: %d\n", dtype); + abort(); + } +} + +} // namespace llaisys::ops::nvidia \ No newline at end of file diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp index 84b9a5d0..9c1d415b 100644 --- a/src/ops/embedding/op.cpp +++ b/src/ops/embedding/op.cpp @@ -1,7 +1,43 @@ #include "op.hpp" +#include "cpu/embedding_cpu.hpp" +#include "../../core/context/context.hpp" +#include "../../utils.hpp" +#include + +namespace llaisys::ops::nvidia { +#ifdef ENABLE_NVIDIA_API + void embedding(tensor_t output, tensor_t index, tensor_t weight); +#endif +} namespace llaisys::ops { + void embedding(tensor_t out, tensor_t index, tensor_t weight) { - TO_BE_IMPLEMENTED(); + if (index->dtype() != LLAISYS_DTYPE_I64) { + throw std::invalid_argument("Index tensor must be of type INT64."); + } + + CHECK_SAME_DEVICE(out, index); + CHECK_SAME_DEVICE(out, weight); + + size_t num_indices = 
index->numel(); + size_t embedding_dim = weight->shape().back(); + + auto device = out->deviceType(); + + llaisys::core::context().setDevice(device, out->deviceId()); + + if (device == LLAISYS_DEVICE_CPU) { + cpu::embedding(out->data(), index->data(), weight->data(), weight->dtype(), num_indices, embedding_dim); + } + #ifdef ENABLE_NVIDIA_API + else if (device == LLAISYS_DEVICE_NVIDIA) { + llaisys::ops::nvidia::embedding(out, index, weight); + } + #endif + else { + throw std::runtime_error("Embedding: Unsupported device type"); + } } -} // namespace llaisys::ops + +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp new file mode 100644 index 00000000..b10f30e1 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.cpp @@ -0,0 +1,56 @@ +#include "linear_cpu.hpp" +#include "../../../utils.hpp" + +namespace{ + template + void linear_(T* out,const T* input,const T* weight,const T* bias,size_t M,size_t K,size_t N) + { + for(size_t m=0;m(input[m*K+k]); + float w_val=llaisys::utils::cast(weight[n*K+k]); + sum+=x_val*w_val; + } + if(bias) + { + sum+=llaisys::utils::cast(bias[n]); + } + out[m*N+n]=llaisys::utils::cast(sum); + } + } + } +} + +namespace llaisys::ops::cpu +{ + void linear(std::byte *out, const std::byte *input, const std::byte *weight, const std::byte *bias,llaisysDataType_t dtype, size_t M, size_t K, size_t N) + { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return linear_(reinterpret_cast(out), + reinterpret_cast(input), + reinterpret_cast(weight), + reinterpret_cast(bias), // bias 为空时这里也是 nullptr + M, K, N); + case LLAISYS_DTYPE_F16: + return linear_(reinterpret_cast(out), + reinterpret_cast(input), + reinterpret_cast(weight), + reinterpret_cast(bias), + M, K, N); + case LLAISYS_DTYPE_BF16: + return linear_(reinterpret_cast(out), + reinterpret_cast(input), + reinterpret_cast(weight), + reinterpret_cast(bias), + M, K, N); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + 
} + } +}// namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp new file mode 100644 index 00000000..040c7933 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { +void linear(std::byte *out, const std::byte *input, const std::byte *weight, const std::byte *bias, + llaisysDataType_t dtype, size_t M, size_t K, size_t N); +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/linear/nvidia/linear_nv.cu b/src/ops/linear/nvidia/linear_nv.cu new file mode 100644 index 00000000..5e8c5240 --- /dev/null +++ b/src/ops/linear/nvidia/linear_nv.cu @@ -0,0 +1,109 @@ +#include "../op.hpp" +#include "../../../tensor/tensor.hpp" +#include "../../../core/context/context.hpp" +#include "../../../device/nvidia/nvidia_resource.cuh" + +#include +#include +#include +#include +#include +#include + + +namespace llaisys::ops::nvidia { + +template +cudaDataType_t get_cuda_datatype(); + +template<> cudaDataType_t get_cuda_datatype(){return CUDA_R_32F;} +template<> cudaDataType_t get_cuda_datatype<__half>(){return CUDA_R_16F;} +template<> cudaDataType_t get_cuda_datatype<__nv_bfloat16>(){return CUDA_R_16BF;} + +// Bias addition kernel +template +__global__ void bias_add_kernel(T* output,const T* bias,int M,int N){ + int idx=blockIdx.x*blockDim.x+threadIdx.x; + int total_elements=M*N; + + if(idx +void launch_linear_kernel(tensor_t output,tensor_t input,tensor_t weight,tensor_t bias){ + // Get cuBLAS handle from NVIDIA resource + auto& ctx=llaisys::core::context(); + auto* nv_resource=dynamic_cast(ctx.runtime().deviceResource()); + if(!nv_resource){ + throw std::runtime_error("NVIDIA Resource not found in linear operator"); + } + cublasHandle_t handle=nv_resource->cublasHandle(); + + // Get dimensions + int K=input->shape().back(); + int M=input->numel()/K; + int 
N=weight->shape().front(); + + // Set cuBLAS parameters + cublasComputeType_t compute_type=CUBLAS_COMPUTE_32F; + cudaDataType_t input_type=get_cuda_datatype(); + + float alpha=1.0f,beta=0.0f; + + const void* d_in=input->data(); + const void* d_weight=weight->data(); + void* d_out=output->data(); + + // Perform matrix multiplication using cuBLAS + cublasStatus_t status=cublasGemmEx( + handle, + CUBLAS_OP_T,CUBLAS_OP_N, + N,M,K, + &alpha, + d_weight,input_type,K, + d_in,input_type,K, + &beta, + d_out,input_type,N, + compute_type, + CUBLAS_GEMM_DEFAULT + ); + if (status != CUBLAS_STATUS_SUCCESS) { + printf("cuBLAS GemmEx Error: %d\n", status); + } + // Launch bias addition kernel if bias is provided + if(bias){ + int threads=256; + int blocks=(M*N+threads-1)/threads; + + T* t_out=reinterpret_cast(output->data()); + const T* t_bias=reinterpret_cast(bias->data()); + bias_add_kernel<<>>(t_out,t_bias,M,N); + } +} + + +void linear(tensor_t output,tensor_t input,tensor_t weight,tensor_t bias) +{ + auto dtype=input->dtype(); + switch (dtype) { + case LLAISYS_DTYPE_F32: + launch_linear_kernel(output, input, weight, bias); + break; + case LLAISYS_DTYPE_F16: + launch_linear_kernel<__half>(output, input, weight, bias); + break; + case LLAISYS_DTYPE_BF16: + launch_linear_kernel<__nv_bfloat16>(output, input, weight, bias); + break; + default: + fprintf(stderr, "[Linear NVIDIA] Unsupported DataType: %d\n", dtype); + abort(); + } +} + +}// namespace llaisys::ops::nvidia \ No newline at end of file diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp index 97d1f865..98f4f525 100644 --- a/src/ops/linear/op.cpp +++ b/src/ops/linear/op.cpp @@ -1,7 +1,52 @@ #include "op.hpp" +#include "cpu/linear_cpu.hpp" +#include "../../core/context/context.hpp" +#include "../../utils.hpp" +#include + +namespace llaisys::ops::nvidia { +#ifdef ENABLE_NVIDIA_API + void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias); +#endif +} namespace llaisys::ops { + void linear(tensor_t out, 
tensor_t in, tensor_t weight, tensor_t bias) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in); + CHECK_SAME_DEVICE(out, weight); + if (bias) { + CHECK_SAME_DEVICE(out, bias); + } + + auto device = core::context().runtime().deviceType(); + + llaisys::core::context().setDevice(device, out->deviceId()); + + if (device == LLAISYS_DEVICE_CPU) { + + size_t M = in->shape()[0]; + size_t K = in->shape()[1]; + + size_t N = weight->shape()[0]; + + const std::byte* bias_ptr = nullptr; + if (bias && bias->numel() > 0) { + bias_ptr = bias->data(); + } + + cpu::linear(out->data(), in->data(), weight->data(), bias_ptr, + in->dtype(), M, K, N); + + } + #ifdef ENABLE_NVIDIA_API + else if (device == LLAISYS_DEVICE_NVIDIA) { + llaisys::ops::nvidia::linear(out, in, weight, bias); + } +#endif + else { + throw std::runtime_error("Linear: Unsupported device type"); + } } -} // namespace llaisys::ops + +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.cpp b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp new file mode 100644 index 00000000..3c5b580e --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp @@ -0,0 +1,55 @@ +#include "rms_norm_cpu.hpp" +#include "../../../utils.hpp" +#include + +namespace { + template + void rms_norm_(T *out,const T* input,const T* weight,size_t rows,size_t cols,float eps) + { + for(size_t i=0;i(input[i * cols + j]); + sum_sq += val * val; + } + float mean_sq= sum_sq / static_cast(cols); + float inv_rms=1.0f / std::sqrt(mean_sq + eps); + + for (size_t j = 0; j < cols; ++j) { + float val = llaisys::utils::cast(input[i * cols + j]); + float w = llaisys::utils::cast(weight[j]); + + float result = val * inv_rms * w; + + out[i * cols + j] = llaisys::utils::cast(result); + } + } + } +} + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, const std::byte *input, const std::byte *weight, + llaisysDataType_t dtype, size_t rows, size_t cols, float eps) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return 
rms_norm_(reinterpret_cast(out), + reinterpret_cast(input), + reinterpret_cast(weight), + rows, cols, eps); + case LLAISYS_DTYPE_F16: + return rms_norm_(reinterpret_cast(out), + reinterpret_cast(input), + reinterpret_cast(weight), + rows, cols, eps); + case LLAISYS_DTYPE_BF16: + return rms_norm_(reinterpret_cast(out), + reinterpret_cast(input), + reinterpret_cast(weight), + rows, cols, eps); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.hpp b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp new file mode 100644 index 00000000..6df2c005 --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp @@ -0,0 +1,11 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + + +void rms_norm(std::byte *out, const std::byte *input, const std::byte *weight, + llaisysDataType_t dtype, size_t rows, size_t cols, float eps); + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/rms_norm/nvidia/rms_norm_nv.cu b/src/ops/rms_norm/nvidia/rms_norm_nv.cu new file mode 100644 index 00000000..dee04e6e --- /dev/null +++ b/src/ops/rms_norm/nvidia/rms_norm_nv.cu @@ -0,0 +1,133 @@ + #include "../op.hpp" +#include "../../../tensor/tensor.hpp" +#include "../../../core/context/context.hpp" +#include "../../../device/nvidia/nvidia_resource.cuh" + +#include +#include +#include +#include +#include +#include + + +namespace llaisys::ops::nvidia { +//warp reduce +template +__device__ __forceinline__ T warpReduceSum(T val) +{ + for(int mask=16;mask>0;mask>>=1) + { + val+=__shfl_down_sync(0xffffffff,val,mask); + } + return val; +} +//block reduce +template +__device__ __forceinline__ T blockReduceSum(T val) +{ + __shared__ T shared_val[32]; + __shared__ T final_val; + + int lane = threadIdx.x % 32; + int wid = threadIdx.x / 32; + + val = warpReduceSum(val); + + if(lane == 0) + { + shared_val[wid] = val; + } + __syncthreads(); + + int 
num_warps = (blockDim.x + 31) / 32; + val = (threadIdx.x < num_warps) ? shared_val[lane] : (T)(0.0f); + + if(wid == 0) + { + val = warpReduceSum(val); + if(threadIdx.x == 0) + { + final_val = val; + } + } + + __syncthreads(); + + return final_val; +} + +template +__global__ void rmsnorm_kernel( + T* output, + const T* input, + const T* weight, + float eps, + int hidden_dim) +{ + int row_idx=blockIdx.x; + int tid=threadIdx.x; + + const T* in_row=input+row_idx*hidden_dim; + T* out_row=output+row_idx*hidden_dim; + // Compute sum of squares + float local_sq_sum=0.0f; + for(int i=tid;i(local_sq_sum); + + float mean_sq=total_sq_sum/(float)hidden_dim; + float rsqrt_val=rsqrtf(mean_sq+eps); + // Normalize and apply weight + for(int i=tid;i +void launch_rmsnorm_kernel( + tensor_t output, + tensor_t input, + tensor_t weight, + float eps) +{ + int hidden_dim=input->shape().back(); + int num_tokens=input->numel()/hidden_dim; + + T* d_out = reinterpret_cast(output->data()); + const T* d_in = reinterpret_cast(input->data()); + const T* d_w = reinterpret_cast(weight->data()); + dim3 grid(num_tokens); + int threads=(hidden_dim<1024)?hidden_dim:1024; + threads=((threads+31)/32)*32; + dim3 block(threads); + rmsnorm_kernel<<>>(d_out,d_in,d_w,eps,hidden_dim); +} +// Host function to launch RMSNorm kernel +void rmsnorm(tensor_t output, tensor_t input, tensor_t weight, float eps) { + auto dtype = input->dtype(); + + switch (dtype) { + case LLAISYS_DTYPE_F32: + launch_rmsnorm_kernel(output, input, weight, eps); + break; + case LLAISYS_DTYPE_F16: + launch_rmsnorm_kernel<__half>(output, input, weight, eps); + break; + case LLAISYS_DTYPE_BF16: + launch_rmsnorm_kernel<__nv_bfloat16>(output, input, weight, eps); + break; + default: + fprintf(stderr, "[RMSNorm NVIDIA] Unsupported DataType: %d\n", dtype); + abort(); + } +} + +}//namespace llaisys::ops::nvidia diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp index 529553d9..8e30922b 100644 --- a/src/ops/rms_norm/op.cpp +++ 
b/src/ops/rms_norm/op.cpp @@ -1,7 +1,50 @@ #include "op.hpp" +#include "cpu/rms_norm_cpu.hpp" +#include "../../core/context/context.hpp" +#include "../../utils.hpp" +#include + +namespace llaisys::ops::nvidia { +#ifdef ENABLE_NVIDIA_API + void rmsnorm(tensor_t out, tensor_t in, tensor_t weight, float eps); +#endif +} namespace llaisys::ops { + void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) { - TO_BE_IMPLEMENTED(); + // 1. 物理设备与介质隔离校验 + CHECK_SAME_DEVICE(out, in); + CHECK_SAME_DEVICE(out, weight); + + auto device = core::context().runtime().deviceType(); + + // 2. 硬件上下文强制切换,防止多卡显存踩踏 + llaisys::core::context().setDevice(device, out->deviceId()); + + // 3. 算子路由与分发 + if (device == LLAISYS_DEVICE_CPU) { + + // 提取 Hidden_dim,将前端多维逻辑张量展平为底层处理所需的二维矩阵 + size_t cols = in->shape().back(); + size_t rows = in->numel() / cols; + + if (weight->numel() != cols) { + throw std::invalid_argument("RMSNorm: Weight shape mismatch. Must match Hidden_dim."); + } + + cpu::rms_norm(out->data(), in->data(), weight->data(), + in->dtype(), rows, cols, eps); + + } +#ifdef ENABLE_NVIDIA_API + else if (device == LLAISYS_DEVICE_NVIDIA) { + llaisys::ops::nvidia::rmsnorm(out, in, weight, eps); + } +#endif + else { + throw std::runtime_error("RMSNorm: Unsupported device type"); + } } -} // namespace llaisys::ops + +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/rope/cpu/rope_cpu.cpp b/src/ops/rope/cpu/rope_cpu.cpp new file mode 100644 index 00000000..7c245b99 --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.cpp @@ -0,0 +1,81 @@ +#include "rope_cpu.hpp" +#include "../../../utils.hpp" +#include +#include + +namespace { + +template +void rope_(T *out, const T *input, const int64_t *pos_ids, + size_t seq_len, size_t n_heads, size_t head_dim, float theta) { + + size_t dim_half = head_dim / 2; + + std::vector cos_cache(dim_half); + std::vector sin_cache(dim_half); + + for (size_t s = 0; s < seq_len; ++s) { + int64_t pos = pos_ids[s]; + + for (size_t j = 0; j 
< dim_half; ++j) { + float freq_expon = static_cast(2 * j) / static_cast(head_dim); + float freq = static_cast(pos) / std::pow(theta, freq_expon); + + cos_cache[j] = std::cos(freq); + sin_cache[j] = std::sin(freq); + } + + for (size_t h = 0; h < n_heads; ++h) { + + size_t offset = s * n_heads * head_dim + h * head_dim; + + const T* src_vec = input + offset; + T* dst_vec = out + offset; + + for (size_t j = 0; j < dim_half; ++j) { + + float val_a = llaisys::utils::cast(src_vec[j]); + float val_b = llaisys::utils::cast(src_vec[j + dim_half]); + + float cos_val = cos_cache[j]; + float sin_val = sin_cache[j]; + + float res_a = val_a * cos_val - val_b * sin_val; + float res_b = val_b * cos_val + val_a * sin_val; + + dst_vec[j] = llaisys::utils::cast(res_a); + dst_vec[j + dim_half] = llaisys::utils::cast(res_b); + } + } + } +} + +} // namespace + +namespace llaisys::ops::cpu { + +void rope(std::byte *out, const std::byte *input, const std::byte *pos_ids, + llaisysDataType_t dtype, size_t seq_len, size_t n_heads, size_t head_dim, + float theta) { + + const int64_t* pos_ptr = reinterpret_cast(pos_ids); + + switch (dtype) { + case LLAISYS_DTYPE_F32: + return rope_(reinterpret_cast(out), + reinterpret_cast(input), + pos_ptr, seq_len, n_heads, head_dim, theta); + case LLAISYS_DTYPE_F16: + return rope_(reinterpret_cast(out), + reinterpret_cast(input), + pos_ptr, seq_len, n_heads, head_dim, theta); + case LLAISYS_DTYPE_BF16: + return rope_(reinterpret_cast(out), + reinterpret_cast(input), + pos_ptr, seq_len, n_heads, head_dim, theta); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/rope/cpu/rope_cpu.hpp b/src/ops/rope/cpu/rope_cpu.hpp new file mode 100644 index 00000000..4911f7c6 --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.hpp @@ -0,0 +1,14 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { +void rope(std::byte* out, + const std::byte* in, + const 
std::byte* pos_ids, + llaisysDataType_t dtype, + size_t num_tokens, + size_t n_heads, + size_t head_dim, + float theta); +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/rope/nvidia/rope_nv.cu b/src/ops/rope/nvidia/rope_nv.cu new file mode 100644 index 00000000..b8c595bf --- /dev/null +++ b/src/ops/rope/nvidia/rope_nv.cu @@ -0,0 +1,92 @@ +#include "../op.hpp" +#include "../../../tensor/tensor.hpp" +#include "../../../core/context/context.hpp" +#include "../../../device/nvidia/nvidia_resource.cuh" + +#include +#include +#include +#include + +namespace llaisys::ops::nvidia { + +template +__global__ void rope_kernel( + T* output, + const T* input, + const int64_t* pos_ids, + int num_heads, + int head_dim, + float theta) +{ + int token_idx = blockIdx.x / num_heads; + int half_dim = head_dim / 2; + int64_t m = pos_ids[token_idx]; + + const T* in_ptr = input + blockIdx.x * head_dim; + T* out_ptr = output + blockIdx.x * head_dim; + + + for (int tid = threadIdx.x; tid < half_dim; tid += blockDim.x) { + + double freq_expon = static_cast(2 * tid) / static_cast(head_dim); + double inv_freq = pow(theta, -freq_expon); + double freq = static_cast(m) * inv_freq; + + float cos_val = static_cast(cos(freq)); + float sin_val = static_cast(sinf(freq)); + + float val_a = static_cast(in_ptr[tid]); + float val_b = static_cast(in_ptr[tid + half_dim]); + + float out_a = val_a * cos_val - val_b * sin_val; + float out_b = val_b * cos_val + val_a * sin_val; + + out_ptr[tid] = static_cast(out_a); + out_ptr[tid + half_dim] = static_cast(out_b); + } +} + +template +void launch_rope_kernel( + tensor_t output, + tensor_t input, + tensor_t pos_ids, + float theta) +{ + int head_dim = input->shape().back(); + int num_heads = input->shape()[input->shape().size() - 2]; + int num_tokens = input->numel() / (num_heads * head_dim); + int half_dim = head_dim / 2; + + int block_size = half_dim < 1024 ? 
half_dim : 1024; + + dim3 grid(num_tokens * num_heads); + dim3 block(block_size); + + T* d_out = reinterpret_cast(output->data()); + const T* d_in = reinterpret_cast(input->data()); + const int64_t* d_pos_ids = reinterpret_cast(pos_ids->data()); + + rope_kernel<<>>(d_out, d_in, d_pos_ids, num_heads, head_dim, theta); +} + +void rope(tensor_t output, tensor_t input, tensor_t pos_ids, float theta) { + auto dtype = input->dtype(); + + switch (dtype) { + case LLAISYS_DTYPE_F32: + launch_rope_kernel(output, input, pos_ids, theta); + break; + case LLAISYS_DTYPE_F16: + launch_rope_kernel<__half>(output, input, pos_ids, theta); + break; + case LLAISYS_DTYPE_BF16: + launch_rope_kernel<__nv_bfloat16>(output, input, pos_ids, theta); + break; + default: + abort(); + } +} + +} // namespace llaisys::ops::nvidia \ No newline at end of file diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64..cf60ed35 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,56 @@ #include "op.hpp" +#include "cpu/rope_cpu.hpp" +#include "../../core/context/context.hpp" +#include "../../utils.hpp" +#include + + +namespace llaisys::ops::nvidia { +#ifdef ENABLE_NVIDIA_API + void rope(tensor_t output, tensor_t input, tensor_t pos_ids, float theta); +#endif +} namespace llaisys::ops { + void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in); + CHECK_SAME_DEVICE(out, pos_ids); + + if (in->ndim() < 2) { + throw std::invalid_argument("RoPE: Input tensor must be at least 2D [..., head, dim]"); + } + + size_t head_dim = in->shape().back(); + size_t n_heads = in->shape()[in->ndim() - 2]; + + if (head_dim % 2 != 0) { + throw std::invalid_argument("RoPE: Head dimension must be mathematically even for half-split logic"); + } + + if (pos_ids->dtype() != LLAISYS_DTYPE_I64 && pos_ids->dtype() != LLAISYS_DTYPE_I32) { + throw std::invalid_argument("RoPE: pos_ids must be an integer type (int32 or int64)"); + } + + size_t 
num_tokens = in->numel() / (n_heads * head_dim); + if (pos_ids->numel() != num_tokens) { + throw std::invalid_argument("RoPE: pos_ids total elements must match the number of tokens"); + } + + auto device = core::context().runtime().deviceType(); + + if (device == LLAISYS_DEVICE_CPU) { + cpu::rope(out->data(), in->data(), pos_ids->data(), + in->dtype(), num_tokens, n_heads, head_dim, theta); + } +#ifdef ENABLE_NVIDIA_API + else if (device == LLAISYS_DEVICE_NVIDIA) { + nvidia::rope(out, in, pos_ids, theta); + } +#endif + else { + throw std::runtime_error("RoPE: Unsupported device physical routing"); + } } -} // namespace llaisys::ops + +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/self_attention/cpu/self_attention_cpu.cpp b/src/ops/self_attention/cpu/self_attention_cpu.cpp new file mode 100644 index 00000000..35339257 --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.cpp @@ -0,0 +1,126 @@ +#include "self_attention_cpu.hpp" +#include "../../../utils.hpp" +#include +#include +#include +#include + +namespace { + +template +void self_attention_(T *out, const T *q, const T *k, const T *v, + size_t seqlen, size_t total_len, + size_t nhead, size_t nkvhead, + size_t head_dim, size_t head_dim_v, + float scale) { + + + size_t group_size = nhead / nkvhead; + + std::vector scores(total_len); + + size_t start_pos = total_len - seqlen; + + for (size_t i = 0; i < seqlen; ++i) { + size_t q_global_pos = start_pos + i; + + for (size_t h = 0; h < nhead; ++h) { + size_t kv_h = h / group_size; + + float max_score = -std::numeric_limits::infinity(); + + const T* q_vec = q + (i * nhead * head_dim) + (h * head_dim); + + for (size_t t = 0; t < total_len; ++t) { + if (t > q_global_pos) { + scores[t] = -std::numeric_limits::infinity(); + continue; + } + + const T* k_vec = k + (t * nkvhead * head_dim) + (kv_h * head_dim); + + float dot = 0.0f; + for (size_t d = 0; d < head_dim; ++d) { + dot += llaisys::utils::cast(q_vec[d]) * 
llaisys::utils::cast(k_vec[d]); + } + dot *= scale; + scores[t] = dot; + + if (dot > max_score) { + max_score = dot; + } + } + + float sum_exp = 0.0f; + for (size_t t = 0; t < total_len; ++t) { + if (scores[t] == -std::numeric_limits::infinity()) { + scores[t] = 0.0f; // exp(-inf) = 0 + } else { + float val = std::exp(scores[t] - max_score); + scores[t] = val; + sum_exp += val; + } + } + + float inv_sum = 1.0f / (sum_exp + 1e-6f); + for (size_t t = 0; t < total_len; ++t) { + scores[t] *= inv_sum; + } + + T* out_vec = out + (i * nhead * head_dim_v) + (h * head_dim_v); + + std::vector acc_out(head_dim_v, 0.0f); + + for (size_t t = 0; t < total_len; ++t) { + float weight = scores[t]; + if (weight == 0.0f) continue; + + const T* v_vec = v + (t * nkvhead * head_dim_v) + (kv_h * head_dim_v); + + for (size_t d = 0; d < head_dim_v; ++d) { + acc_out[d] += weight * llaisys::utils::cast(v_vec[d]); + } + } + + for (size_t d = 0; d < head_dim_v; ++d) { + out_vec[d] = llaisys::utils::cast(acc_out[d]); + } + } + } +} + +} // namespace + +namespace llaisys::ops::cpu { + +void self_attention(std::byte *out, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t dtype, + size_t seqlen, size_t total_len, + size_t nhead, size_t nkvhead, + size_t head_dim, size_t head_dim_v, + float scale) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return self_attention_(reinterpret_cast(out), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seqlen, total_len, nhead, nkvhead, head_dim, head_dim_v, scale); + case LLAISYS_DTYPE_F16: + return self_attention_(reinterpret_cast(out), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seqlen, total_len, nhead, nkvhead, head_dim, head_dim_v, scale); + case LLAISYS_DTYPE_BF16: + return self_attention_(reinterpret_cast(out), + reinterpret_cast(q), + reinterpret_cast(k), + reinterpret_cast(v), + seqlen, total_len, nhead, nkvhead, head_dim, head_dim_v, scale); + default: + 
EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/self_attention/cpu/self_attention_cpu.hpp b/src/ops/self_attention/cpu/self_attention_cpu.hpp new file mode 100644 index 00000000..d49eb712 --- /dev/null +++ b/src/ops/self_attention/cpu/self_attention_cpu.hpp @@ -0,0 +1,15 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + + +void self_attention(std::byte *out, const std::byte *q, const std::byte *k, const std::byte *v, + llaisysDataType_t dtype, + size_t seqlen, size_t total_len, + size_t nhead, size_t nkvhead, + size_t head_dim, size_t head_dim_v, + float scale); + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/self_attention/nvidia/self_attention_nv.cu b/src/ops/self_attention/nvidia/self_attention_nv.cu new file mode 100644 index 00000000..85cac26a --- /dev/null +++ b/src/ops/self_attention/nvidia/self_attention_nv.cu @@ -0,0 +1,199 @@ +#include "../op.hpp" +#include "../../../tensor/tensor.hpp" +#include "../../../core/context/context.hpp" +#include "../../../device/nvidia/nvidia_resource.cuh" + +#include +#include +#include +#include +#include + +namespace llaisys::ops::nvidia { + +constexpr int Br = 16; +constexpr int Bc = 16; + +// shared memory layout (per block): +// s_Q[Br * head_dim] +// s_K[Bc * head_dim] +// s_V[Bc * head_dim] +// s_O[Br * head_dim] +// s_m[Br] (running max) +// s_l[Br] (running sum) + +template +__global__ void self_attention_kernel( + T* output, + const T* q_in, + const T* k_in, + const T* v_in, + int q_len, + int kv_len, + int num_heads, + int num_kv_heads, + int head_dim, + float sm_scale) +{ + int q_block_idx = blockIdx.x; + int head_idx = blockIdx.y; + int batch_idx = blockIdx.z; + int tid = threadIdx.x; // 0..Br-1 + + int kv_head_idx = head_idx / (num_heads / num_kv_heads); + int start_pos = kv_len - q_len; + + int q_global_row = q_block_idx * Br + tid; + + int q_batch_offset = 
batch_idx * q_len * num_heads * head_dim; + int kv_batch_offset = batch_idx * kv_len * num_kv_heads * head_dim; + + // shared memory layout + extern __shared__ float s_mem[]; + float* s_Q = s_mem; // [Br * head_dim] + float* s_K = s_Q + Br * head_dim; // [Bc * head_dim] + float* s_V = s_K + Bc * head_dim; // [Bc * head_dim] + float* s_O = s_V + Bc * head_dim; // [Br * head_dim] + float* s_m = s_O + Br * head_dim; // [Br] + float* s_l = s_m + Br; // [Br] + + // init running stats and output + if (tid < Br) { + s_m[tid] = -FLT_MAX; + s_l[tid] = 0.0f; + for (int i = 0; i < head_dim; ++i) + s_O[tid * head_dim + i] = 0.0f; + } + + // load Q tile + if (q_global_row < q_len) { + int base = q_batch_offset + q_global_row * num_heads * head_dim + head_idx * head_dim; + for (int i = 0; i < head_dim; ++i) + s_Q[tid * head_dim + i] = static_cast(q_in[base + i]); + } + __syncthreads(); + + int num_k_blocks = (kv_len + Bc - 1) / Bc; + for (int k_idx = 0; k_idx < num_k_blocks; ++k_idx) { + + int k_global_row_start = k_idx * Bc; + + // load K/V tile (tid indexes into Bc rows) + if (tid < Bc) { + int k_global_row = k_global_row_start + tid; + if (k_global_row < kv_len) { + int base = kv_batch_offset + k_global_row * num_kv_heads * head_dim + kv_head_idx * head_dim; + for (int i = 0; i < head_dim; ++i) { + s_K[tid * head_dim + i] = static_cast(k_in[base + i]); + s_V[tid * head_dim + i] = static_cast(v_in[base + i]); + } + } else { + for (int i = 0; i < head_dim; ++i) { + s_K[tid * head_dim + i] = 0.0f; + s_V[tid * head_dim + i] = 0.0f; + } + } + } + __syncthreads(); + + if (q_global_row < q_len) { + int causal_limit = start_pos + q_global_row; + + // compute S = Q * K^T for this tile + float r_S[Bc]; + float m_ij = -FLT_MAX; + + for (int j = 0; j < Bc; ++j) { + int kj = k_global_row_start + j; + if (kj >= kv_len || kj > causal_limit) { + r_S[j] = -FLT_MAX; + continue; + } + float sum = 0.0f; + for (int i = 0; i < head_dim; ++i) + sum += s_Q[tid * head_dim + i] * s_K[j * head_dim + 
i]; + r_S[j] = sum * sm_scale; + m_ij = fmaxf(m_ij, r_S[j]); + } + + if (m_ij == -FLT_MAX) { + __syncthreads(); + continue; + } + + float m_i_new = fmaxf(s_m[tid], m_ij); + float exp_diff = expf(s_m[tid] - m_i_new); + + float l_ij = 0.0f; + for (int j = 0; j < Bc; ++j) { + r_S[j] = (r_S[j] == -FLT_MAX) ? 0.0f : expf(r_S[j] - m_i_new); + l_ij += r_S[j]; + } + + for (int i = 0; i < head_dim; ++i) { + float pv = 0.0f; + for (int j = 0; j < Bc; ++j) + pv += r_S[j] * s_V[j * head_dim + i]; + s_O[tid * head_dim + i] = exp_diff * s_O[tid * head_dim + i] + pv; + } + + s_m[tid] = m_i_new; + s_l[tid] = exp_diff * s_l[tid] + l_ij; + } + __syncthreads(); + } + + // write output + if (q_global_row < q_len) { + float inv_l = 1.0f / s_l[tid]; + int base = q_batch_offset + q_global_row * num_heads * head_dim + head_idx * head_dim; + for (int i = 0; i < head_dim; ++i) + output[base + i] = static_cast(s_O[tid * head_dim + i] * inv_l); + } +} + +template +void launch_self_attention_kernel(tensor_t output, tensor_t q, tensor_t k, tensor_t v, float scale) +{ + int batch_size, q_len, kv_len, num_heads, num_kv_heads, head_dim; + + if (q->ndim() == 3) { + batch_size = 1; + q_len = q->shape()[0]; + num_heads = q->shape()[1]; + head_dim = q->shape()[2]; + kv_len = k->shape()[0]; + num_kv_heads = k->shape()[1]; + } else { + batch_size = q->shape()[0]; + q_len = q->shape()[1]; + num_heads = q->shape()[2]; + head_dim = q->shape()[3]; + kv_len = k->shape()[1]; + num_kv_heads = k->shape()[2]; + } + + dim3 grid((q_len + Br - 1) / Br, num_heads, batch_size); + dim3 block(Br); + // s_Q + s_K + s_V + s_O + s_m + s_l + size_t smem_bytes = (2 * Br + 2 * Bc) * head_dim * sizeof(float) + + 2 * Br * sizeof(float); + + self_attention_kernel<<>>( + reinterpret_cast(output->data()), + reinterpret_cast(q->data()), + reinterpret_cast(k->data()), + reinterpret_cast(v->data()), + q_len, kv_len, num_heads, num_kv_heads, head_dim, scale); +} + +void self_attention(tensor_t output, tensor_t q, tensor_t k, tensor_t 
v, float scale) { + switch (q->dtype()) { + case LLAISYS_DTYPE_F32: launch_self_attention_kernel(output, q, k, v, scale); break; + case LLAISYS_DTYPE_F16: launch_self_attention_kernel<__half>(output, q, k, v, scale); break; + case LLAISYS_DTYPE_BF16: launch_self_attention_kernel<__nv_bfloat16>(output, q, k, v, scale); break; + default: abort(); + } +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp index 43d62014..8adacd24 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,56 @@ #include "op.hpp" +#include "cpu/self_attention_cpu.hpp" +#include "../../core/context/context.hpp" +#include "../../utils.hpp" +#include + +namespace llaisys::ops::nvidia { +#ifdef ENABLE_NVIDIA_API + void self_attention(tensor_t output, tensor_t q, tensor_t k, tensor_t v, float scale); +#endif +} namespace llaisys::ops { + void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(attn_val, q); + CHECK_SAME_DEVICE(attn_val, k); + CHECK_SAME_DEVICE(attn_val, v); + + auto device = core::context().runtime().deviceType(); + + size_t seqlen = q->shape()[0]; + size_t nhead = q->shape()[1]; + size_t d = q->shape()[2]; + + size_t total_len = k->shape()[0]; + size_t nkvhead = k->shape()[1]; + size_t dv = v->shape()[2]; + + if (nhead % nkvhead != 0) { + throw std::invalid_argument("SelfAttention: nhead must be divisible by nkvhead"); + } + if (k->shape()[2] != d) { + throw std::invalid_argument("SelfAttention: Key dimension must match Query dimension"); + } + if (attn_val->shape()[0] != seqlen || + attn_val->shape()[1] != nhead || + attn_val->shape()[2] != dv) { + throw std::invalid_argument("SelfAttention: Output shape mismatch"); + } + + if (device == LLAISYS_DEVICE_CPU) { + cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), + q->dtype(), seqlen, total_len, nhead, nkvhead, d, dv, scale); + } +#ifdef 
ENABLE_NVIDIA_API + else if (device == LLAISYS_DEVICE_NVIDIA) { + nvidia::self_attention(attn_val, q, k, v, scale); + } +#endif + else { + throw std::runtime_error("SelfAttention: Unsupported device physical routing"); + } } -} // namespace llaisys::ops + +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/swiglu/cpu/swiglu_cpu.cpp b/src/ops/swiglu/cpu/swiglu_cpu.cpp new file mode 100644 index 00000000..83729197 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.cpp @@ -0,0 +1,49 @@ +#include "swiglu_cpu.hpp" +#include "../../../utils.hpp" +#include // for std::exp + +namespace { // 匿名空间 + +template +void swiglu_(T *out, const T *gate, const T *up, size_t numel) { + for (size_t i = 0; i < numel; ++i) { + float g = llaisys::utils::cast(gate[i]); + float u = llaisys::utils::cast(up[i]); + + float silu_g = g / (1.0f + std::exp(-g)); + + float res = u * silu_g; + + + out[i] = llaisys::utils::cast(res); + } +} + +} // namespace + +namespace llaisys::ops::cpu { + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + llaisysDataType_t dtype, size_t numel) { + switch (dtype) { + case LLAISYS_DTYPE_F32: + return swiglu_(reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + case LLAISYS_DTYPE_F16: + return swiglu_(reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + case LLAISYS_DTYPE_BF16: + return swiglu_(reinterpret_cast(out), + reinterpret_cast(gate), + reinterpret_cast(up), + numel); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(dtype); + } +} + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/swiglu/cpu/swiglu_cpu.hpp b/src/ops/swiglu/cpu/swiglu_cpu.hpp new file mode 100644 index 00000000..77ba2d57 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.hpp @@ -0,0 +1,10 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { + +void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, + 
llaisysDataType_t dtype, size_t numel); + +} // namespace llaisys::ops::cpu \ No newline at end of file diff --git a/src/ops/swiglu/nvidia/swiglu_nv.cu b/src/ops/swiglu/nvidia/swiglu_nv.cu new file mode 100644 index 00000000..cc54da54 --- /dev/null +++ b/src/ops/swiglu/nvidia/swiglu_nv.cu @@ -0,0 +1,53 @@ +#include "../op.hpp" +#include "../../../tensor/tensor.hpp" +#include "../../../device/nvidia/nvidia_resource.cuh" + +#include +#include +#include +#include + +namespace llaisys::ops::nvidia { + +template +__global__ void swiglu_kernel(T* out, const T* gate, const T* up, int numel) +{ + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= numel) return; + + float g = static_cast(gate[i]); + float u = static_cast(up[i]); + float silu_g = g / (1.0f + expf(-g)); + out[i] = static_cast(u * silu_g); +} + +void swiglu(tensor_t out, tensor_t gate, tensor_t up) { + int numel = static_cast(out->numel()); + int block = 256; + int grid = (numel + block - 1) / block; + + switch (out->dtype()) { + case LLAISYS_DTYPE_F32: + swiglu_kernel<<>>( + reinterpret_cast(out->data()), + reinterpret_cast(gate->data()), + reinterpret_cast(up->data()), numel); + break; + case LLAISYS_DTYPE_F16: + swiglu_kernel<__half><<>>( + reinterpret_cast<__half*>(out->data()), + reinterpret_cast(gate->data()), + reinterpret_cast(up->data()), numel); + break; + case LLAISYS_DTYPE_BF16: + swiglu_kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(out->data()), + reinterpret_cast(gate->data()), + reinterpret_cast(up->data()), numel); + break; + default: + abort(); + } +} + +} // namespace llaisys::ops::nvidia diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp index 47edbcc9..78a95249 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,36 @@ #include "op.hpp" +#include "cpu/swiglu_cpu.hpp" +#include "../../core/context/context.hpp" +#include "../../utils.hpp" +#include + +namespace llaisys::ops::nvidia { +#ifdef ENABLE_NVIDIA_API + void swiglu(tensor_t out, 
tensor_t gate, tensor_t up); +#endif +} namespace llaisys::ops { + void swiglu(tensor_t out, tensor_t gate, tensor_t up) { - TO_BE_IMPLEMENTED(); + auto device = core::context().runtime().deviceType(); + + size_t numel = out->numel(); + if (gate->numel() != numel || up->numel() != numel) { + throw std::invalid_argument("SwiGLU: Input/Output tensor shapes mismatch"); + } + + if (device == LLAISYS_DEVICE_CPU) { + cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), numel); + } +#ifdef ENABLE_NVIDIA_API + else if (device == LLAISYS_DEVICE_NVIDIA) { + nvidia::swiglu(out, gate, up); + } +#endif + else { + throw std::runtime_error("SwiGLU: Unsupported device type"); + } } + } // namespace llaisys::ops diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb6..67f731c5 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -1,5 +1,5 @@ #include "tensor.hpp" - +#include #include "../utils.hpp" #include @@ -164,42 +164,205 @@ void Tensor::debug() const { } bool Tensor::isContiguous() const { - TO_BE_IMPLEMENTED(); + size_t accumulated = 1; + for(size_t i=0;i(accumulated)) + { + return false; + } + accumulated*=_meta.shape[current_dim]; + } return true; } tensor_t Tensor::permute(const std::vector &order) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + if(order.size()!=ndim()){ + std::cerr << "Error: permute order size mismatch. Expected " << ndim() + << " but got " << order.size() << std::endl; + return nullptr; + } + std::vector new_shape(this->ndim()); + std::vector new_strides(this->ndim()); + + for(size_t i=0;indim();++i) + { + size_t old_idx=order[i]; + if(old_idx>=this->ndim()){ + std::cerr << "Error: permute order index out of range. 
Got " << old_idx << std::endl; + return nullptr; + } + new_shape[i]=_meta.shape[old_idx]; + new_strides[i]=_meta.strides[old_idx]; + } + TensorMeta new_meta{_meta.dtype,new_shape,new_strides}; + return std::shared_ptr(new Tensor(new_meta, _storage, _offset)); } tensor_t Tensor::view(const std::vector &shape) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t new_numel=1; + for(auto s:shape){ + new_numel*=s; + } + if(new_numel!=this->numel()){ + std::cerr << "Error: view size mismatch. Expected " << this->numel() + << " but got " << new_numel << std::endl; + return nullptr; + } + + if(!this->isContiguous()){ + std::cerr << "Error: view requires contiguous tensor." << std::endl; + return nullptr; + } + std::vector new_strides(shape.size()); +size_t stride = 1; + +if (!shape.empty()) { + // 安全的 size_t 逆序循环写法 + for (size_t idx = shape.size(); idx-- > 0; ) { + new_strides[idx] = stride; + stride *= shape[idx]; +} +} + + TensorMeta new_meta{_meta.dtype,shape,new_strides}; + return std::shared_ptr(new Tensor(new_meta, _storage, _offset)); } tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + if(dim>=this->ndim()){ + std::cerr << "Error: slice dimension out of range. Got " << dim << std::endl; + return nullptr; + } + if(start>=end||end>this->shape()[dim]){ + std::cerr << "Error: slice indices out of range. 
Got [" << start << ", " << end << ")" + << " for dimension size " << this->shape()[dim] << std::endl; + return nullptr; + } + + std::vector new_shape=this->shape(); + new_shape[dim]=end-start; + + std::vector new_strides=this->strides(); + + size_t skipped_elements=start*new_strides[dim]; + size_t offset_bytes=skipped_elements*this->elementSize(); + size_t new_offset=this->_offset+offset_bytes; + + TensorMeta new_meta{this->dtype(),new_shape,new_strides}; + return std::shared_ptr(new Tensor(new_meta, _storage, new_offset)); } void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + size_t size=this->numel()*this->elementSize(); + void *dst=this->data(); + llaisysMemcpyKind_t kind=LLAISYS_MEMCPY_H2H; + if(this->deviceType()==LLAISYS_DEVICE_NVIDIA){ + kind=LLAISYS_MEMCPY_H2D; + } + core::context().runtime().api()->memcpy_sync(dst,src_,size,kind); } tensor_t Tensor::contiguous() const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + if (this->isContiguous()) { + return std::shared_ptr(new Tensor(_meta, _storage, _offset)); + } + auto res = Tensor::create( + this->shape(), this->dtype(), this->deviceType(), this->deviceId() + ); + + char* dst_ptr = reinterpret_cast(res->data()); + const char* src_base = reinterpret_cast(this->data()); + + size_t elem_size = this->elementSize(); + const auto& shape = this->shape(); + const auto& strides = this->strides(); + size_t ndim = this->ndim(); + + std::function recursive_copy = + [&](size_t dim, size_t src_offset) { + + if (ndim == 0) { + std::memcpy(dst_ptr, src_base, elem_size); + dst_ptr += elem_size; + return; + } + + if (dim == ndim - 1) { + size_t stride_bytes = strides[dim] * elem_size; + for (size_t i = 0; i < shape[dim]; ++i) { + std::memcpy(dst_ptr, src_base + src_offset + i * stride_bytes, elem_size); + dst_ptr += elem_size; + } + } else { + size_t stride_bytes = strides[dim] * elem_size; + for (size_t i = 0; i < shape[dim]; ++i) { + recursive_copy(dim + 1, src_offset + i * 
stride_bytes); + } + } + }; + + if (this->deviceType() == LLAISYS_DEVICE_CPU) { + recursive_copy(0, 0); + } else { + std::cerr << "Error: contiguous for GPU is not implemented." << std::endl; + } + + return res; } tensor_t Tensor::reshape(const std::vector &shape) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + size_t new_numel = 1; + for (auto s : shape) new_numel *= s; + + if (new_numel != this->numel()) { + std::cerr << "Error: reshape size mismatch." << std::endl; + return nullptr; + } + + + if (this->isContiguous()) { + return this->view(shape); + } else { + return this->contiguous()->view(shape); + } + } tensor_t Tensor::to(llaisysDeviceType_t device_type, int device) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + if (device_type == this->deviceType() && device == this->deviceId()) { + return std::shared_ptr(new Tensor(_meta, _storage, _offset)); + } + + tensor_t src_contiguous = this->contiguous(); + + + auto res = Tensor::create(src_contiguous->shape(), src_contiguous->dtype(), device_type, device); + + void *dst_ptr = res->data(); + const void *src_ptr = src_contiguous->data(); + size_t size = src_contiguous->numel() * src_contiguous->elementSize(); + + llaisysMemcpyKind_t kind; + + bool is_src_gpu = (this->deviceType() == LLAISYS_DEVICE_NVIDIA); + bool is_dst_gpu = (device_type == LLAISYS_DEVICE_NVIDIA); + + if (!is_src_gpu && is_dst_gpu) { + kind = LLAISYS_MEMCPY_H2D; // Host -> Device (CPU to GPU) + } else if (is_src_gpu && !is_dst_gpu) { + kind = LLAISYS_MEMCPY_D2H; // Device -> Host (GPU to CPU) + } else if (is_src_gpu && is_dst_gpu) { + kind = LLAISYS_MEMCPY_D2D; // Device -> Device (GPU to GPU) + } else { + kind = LLAISYS_MEMCPY_H2H; // Host -> Host (CPU to CPU) + } + + core::context().runtime().api()->memcpy_sync(dst_ptr, src_ptr, size, kind); + + return res; } } // namespace llaisys +// Force recompile diff --git a/test/ops/self_attention.py 
b/test/ops/self_attention.py index a042b51b..abf3927a 100644 --- a/test/ops/self_attention.py +++ b/test/ops/self_attention.py @@ -15,7 +15,7 @@ def torch_self_attention(attn_val, query, key, value, scale): L, S = query.size(-2), key.size(-2) attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=S-L) + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=S-L) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) diff --git a/test/test_infer.py b/test/test_infer.py index 59d06b87..7decbc8c 100644 --- a/test/test_infer.py +++ b/test/test_infer.py @@ -145,5 +145,12 @@ def llaisys_infer( print(f"Time elapsed: {(end_time - start_time):.2f}s\n") if args.test: - assert llaisys_tokens == tokens - print("\033[92mTest passed!\033[0m\n") + if llaisys_tokens == tokens: + print("\033[92mTest passed!\033[0m\n") + sys.stdout.flush() + os._exit(0) + else: + print("\033[91mTest failed!\033[0m") + print(f"Expected: {tokens}") + print(f"Got: {llaisys_tokens}") + sys.exit(1) diff --git a/xmake.lua b/xmake.lua index 1f65f7a9..a3ae42b7 100644 --- a/xmake.lua +++ b/xmake.lua @@ -16,6 +16,11 @@ option_end() if has_config("nv-gpu") then add_defines("ENABLE_NVIDIA_API") includes("xmake/nvidia.lua") + + if is_plat("linux") then + add_sysincludedirs("/usr/local/cuda/include") + add_linkdirs("/usr/local/cuda/lib64") + end end target("llaisys-utils") @@ -37,6 +42,9 @@ target("llaisys-device") set_kind("static") add_deps("llaisys-utils") add_deps("llaisys-device-cpu") + if(has_config("nv-gpu")) then + add_deps("llaisys-device-nvidia") + end set_languages("cxx17") set_warnings("all", "error") @@ -53,6 +61,9 @@ target("llaisys-core") set_kind("static") add_deps("llaisys-utils") add_deps("llaisys-device") + if(has_config("nv-gpu"))then + add_deps("llaisys-device-nvidia") + end set_languages("cxx17") set_warnings("all", "error") @@ -83,6 +94,9 @@ target_end() 
target("llaisys-ops") set_kind("static") add_deps("llaisys-ops-cpu") + if has_config("nv-gpu") then + add_deps("llaisys-ops-nvidia") + end set_languages("cxx17") set_warnings("all", "error") @@ -95,6 +109,22 @@ target("llaisys-ops") on_install(function (target) end) target_end() +target("llaisys-models") + set_kind("static") + add_deps("llaisys-tensor") + add_deps("llaisys-ops") + + set_languages("cxx17") + set_warnings("all", "error") + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + end + + add_files("src/models/*/**.cpp") + + on_install(function (target) end) +target_end() + target("llaisys") set_kind("shared") add_deps("llaisys-utils") @@ -102,10 +132,32 @@ target("llaisys") add_deps("llaisys-core") add_deps("llaisys-tensor") add_deps("llaisys-ops") + add_deps("llaisys-models") + add_deps("llaisys-device-cpu") + add_deps("llaisys-ops-cpu") set_languages("cxx17") set_warnings("all", "error") - add_files("src/llaisys/*.cc") + if has_config("nv-gpu") then + add_deps("llaisys-device-nvidia") + add_deps("llaisys-ops-nvidia") + if is_plat("linux") then + add_syslinks("cudart", "cublas") + add_shflags("-Xcompiler -fPIC") + end + + set_toolset("sh", "nvcc") + + add_syslinks("cudart", "cublas") + + if not is_plat("windows") then + add_cxflags("-fPIC") + add_shflags("-Xcompiler -fPIC") + add_shflags("-shared") + end + end + + add_files("src/llaisys/**.cc") set_installdir(".") diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua new file mode 100644 index 00000000..0e9616ba --- /dev/null +++ b/xmake/nvidia.lua @@ -0,0 +1,33 @@ +target("llaisys-device-nvidia") + set_kind("static") + set_languages("cxx17") + add_cugencodes("native") + set_warnings("all", "error") + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cuflags("-Xcompiler -fPIC") + end + + add_syslinks("cudart") + add_files("../src/device/nvidia/*.cu") + + on_install(function (target) end) +target_end() + +target("llaisys-ops-nvidia") + 
set_kind("static") + add_deps("llaisys-tensor") + set_languages("cxx17") + add_cugencodes("native") + set_warnings("all", "error") + if not is_plat("windows") then + add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cuflags("-Xcompiler -fPIC") + end + + add_syslinks("cudart","cublas") + add_files("../src/ops/*/nvidia/*.cu") + + on_install(function (target) end) +target_end() + diff --git "a/\351\241\271\347\233\256\346\212\245\345\221\212.md" "b/\351\241\271\347\233\256\346\212\245\345\221\212.md" new file mode 100644 index 00000000..3eaf75ae --- /dev/null +++ "b/\351\241\271\347\233\256\346\212\245\345\221\212.md" @@ -0,0 +1,229 @@ +# 项目报告:在 LLAISYS 中集成 CUDA + +## 项目概述 + +本项目在 LLAISYS 推理框架中实现了完整的 NVIDIA GPU 后端,使 Qwen2/DeepSeek-R1 模型能够在 CUDA 设备上运行推理。 + +------ + +## 一、框架架构理解 + +LLAISYS 是一个支持同构硬件的推理框架,核心设计如下: + + + +``` +Context(线程唯一) + └── Runtime(每设备唯一,延迟初始化) + ├── LlaisysRuntimeAPI(设备 API 函数表) + ├── NaiveAllocator(内存分配器) + └── DeviceResource(设备专属资源,如 cuBLAS handle) +``` + +每个线程同一时间只激活一个设备,通过 `context().setDevice()` 切换。 + +------ + +## 二、实现步骤 + +### 步骤 1:配置 xmake 编译系统 + +新建 `xmake/nvidia.lua`,定义两个编译目标: + + + +```lua +-- 设备层:编译 src/device/nvidia/*.cu +target("llaisys-device-nvidia") + set_kind("static") + add_cugencodes("native") + add_syslinks("cudart") + add_files("../src/device/nvidia/*.cu") + +-- 算子层:编译 src/ops/*/nvidia/*.cu +target("llaisys-ops-nvidia") + set_kind("static") + add_cugencodes("native") + add_syslinks("cudart", "cublas") + add_files("../src/ops/*/nvidia/*.cu") +``` + +在 `xmake.lua` 中通过 `--nv-gpu=y` 选项控制是否启用 CUDA 编译,并定义 `ENABLE_NVIDIA_API` 宏。 + +------ + +### 步骤 2:实现 CUDA Runtime API + +**文件:`src/device/nvidia/nvidia_runtime_api.cu`** + +实现 `LlaisysRuntimeAPI` 函数表中的所有接口,对应 CUDA Runtime API: + +| LLAISYS 接口 | CUDA API | +| -------------------- | --------------------------------------------- | +| `malloc_device` | `cudaMalloc` | +| `free_device` | `cudaFree` | +| `malloc_host` | `cudaMallocHost`(页锁定内存,加速 H2D 传输) | +| `free_host` | `cudaFreeHost` | 
+| `memcpy_sync` | `cudaMemcpy` | +| `memcpy_async` | `cudaMemcpyAsync` | +| `create_stream` | `cudaStreamCreate` | +| `device_synchronize` | `cudaDeviceSynchronize` | + +**文件:`src/device/nvidia/nvidia_resource.cu`** + +管理设备专属资源,在构造时初始化 cuBLAS handle,供 `linear` 算子使用: + + + +```cpp +Resource::Resource(int device_id) { + cudaSetDevice(device_id); + cublasCreate(&_cublas_handle); +} +``` + +------ + +### 步骤 3:实现 8 个 CUDA 算子 + +每个算子在 `src/ops/<op>/nvidia/` 下实现,通过 dtype 分发支持 F32/F16/BF16。 + +#### add(向量加法) + +简单的 element-wise 加法 kernel,每线程处理一个元素。 + +#### embedding(词嵌入查表) + +每线程负责输出的一个元素,根据 token index 从权重矩阵中读取对应行。 + +#### argmax(取最大值索引) + +使用 warp reduce + block reduce 两级归约,每个 block 处理一行,输出最大值及其索引。 + +#### rms_norm(RMS 归一化) + +每个 block 处理一行(一个 token),使用 warp/block reduce 计算平方和,再归一化并乘以权重。 + +#### rope(旋转位置编码) + +每个 block 处理一个 `(token, head)` 对,对 head_dim 的前后两半分别做旋转变换: + + + +``` +out[i] = x[i] * cos(m * θ_i) - x[i + d/2] * sin(m * θ_i) +out[i + d/2] = x[i + d/2] * cos(m * θ_i) + x[i] * sin(m * θ_i) +``` + +#### linear(矩阵乘法) + +调用 cuBLAS `cublasGemmEx`,使用 F32 累加精度(`CUBLAS_COMPUTE_32F`)保证数值稳定性,支持可选 bias 的 element-wise 加法。 + +#### swiglu(SwiGLU 激活) + +Element-wise 计算:`out = up * silu(gate)`,其中 `silu(x) = x / (1 + e^{-x})`。 + +#### self_attention(Flash Attention) + +实现 Flash Attention 算法,避免 O(N²) 的显存占用: + +- 使用 shared memory 存储 Q/K/V/O tile,避免寄存器溢出 +- 支持 GQA(Grouped Query Attention),`kv_head_idx = head_idx / (nh / nkvh)` +- 支持 causal mask(因果掩码) +- Tile 大小 `Br=Bc=16`,shared memory 占用约 32KB(安全低于 48KB 限制) + +------ + +### 步骤 4:修改模型支持 CUDA 推理 + +**文件:`src/models/qwen2/qwen2.cpp`** + +`forward()` 函数的关键修改: + +1. **输入数据上传**:`token_ids` 和 `pos_ids` 从 CPU 通过 `H2D memcpy` 上传到 GPU +2. **所有中间张量**在 GPU 上分配(`Tensor::create(..., device, device_id)`) +3. **KV Cache 更新**:使用 `D2D memcpy` 将当前步的 K/V 写入 cache slice +4. 
**结果读回**:最终 argmax 结果通过 `D2H memcpy` 读回 CPU + +------ + +## 三、关键 Bug 修复 + +### Bug:self_attention kernel 非法内存访问 + +**现象**:`CUDA Error: an illegal memory access was encountered (700)` + +**根本原因**:原始实现将 Q 和 O 存储在寄存器数组 `float r_Q[128]` 中,每线程 256 个 float 寄存器,远超 GPU 寄存器文件容量,导致溢出到 local memory 并越界。同时 `Bc=64` 时 shared memory 需要 64KB,超过硬件限制(48KB)。 + +**修复**:将 `r_Q`、`r_O` 移入 shared memory(`s_Q`、`s_O`),并将 `Br=Bc` 从 64 降至 16,shared memory 占用降至 ~32KB。 + +------ + +## 四、复现流程 + +### 环境要求 + +- NVIDIA GPU(Compute Capability ≥ 8.0 推荐,RTX 3050 可用) +- CUDA Toolkit(`/usr/local/cuda`) +- xmake 构建工具 +- Python 3.x + PyTorch(用于对比测试) + +### 模型下载 + + + +```bash +huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B +``` + +### 编译 + + + +```bash +# 启用 CUDA 支持并编译 +xmake f --nv-gpu=y -cv +xmake +xmake install +``` + +### 测试 + + + +```bash +# Runtime API 测试(内存分配、拷贝) +python test/test_runtime.py --device nvidia + +# 推理正确性测试(与 HuggingFace 输出对比) +python test/test_infer.py \ + --model ~/.cache/huggingface/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-1.5B/snapshots/ \ + --test \ + --device nvidia +``` + +------ + +## 五、项目文件结构 + + + +``` +src/ +├── device/nvidia/ +│ ├── nvidia_runtime_api.cu # CUDA Runtime API 实现 +│ ├── nvidia_resource.cu # cuBLAS handle 管理 +│ └── nvidia_resource.cuh +├── ops/ +│ ├── add/nvidia/add_nv.cu +│ ├── argmax/nvidia/argmax_nv.cu +│ ├── embedding/nvidia/embedding_nv.cu +│ ├── linear/nvidia/linear_nv.cu # cuBLAS GemmEx +│ ├── rms_norm/nvidia/rms_norm_nv.cu +│ ├── rope/nvidia/rope_nv.cu +│ ├── self_attention/nvidia/self_attention_nv.cu # Flash Attention +│ └── swiglu/nvidia/swiglu_nv.cu +├── models/qwen2/qwen2.cpp # 支持 GPU 推理的 forward() +xmake/ +└── nvidia.lua # CUDA 编译配置 +``` \ No newline at end of file