diff --git a/python/llaisys/libllaisys/models.py b/python/llaisys/libllaisys/models.py new file mode 100644 index 000000000..6901bc7c5 --- /dev/null +++ b/python/llaisys/libllaisys/models.py @@ -0,0 +1,49 @@ +from ctypes import c_void_p, c_int64, c_size_t, c_float, c_int, POINTER, Structure +from . import LIB_LLAISYS + +class Qwen2Meta(Structure): + _fields_ = [ + ("dtype", c_int), + ("nlayer", c_size_t), + ("hs", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("maxseq", c_size_t), + ("voc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_int64), + ] + +class Qwen2Weights(Structure): + _fields_ = [ + ("in_embed", c_void_p), + ("out_embed", c_void_p), + ("out_norm_w", c_void_p), + ("attn_norm_w", POINTER(c_void_p)), + ("attn_q_w", POINTER(c_void_p)), + ("attn_q_b", POINTER(c_void_p)), + ("attn_k_w", POINTER(c_void_p)), + ("attn_k_b", POINTER(c_void_p)), + ("attn_v_w", POINTER(c_void_p)), + ("attn_v_b", POINTER(c_void_p)), + ("attn_o_w", POINTER(c_void_p)), + ("mlp_norm_w", POINTER(c_void_p)), + ("mlp_gate_w", POINTER(c_void_p)), + ("mlp_up_w", POINTER(c_void_p)), + ("mlp_down_w", POINTER(c_void_p)), + ] + +LIB_LLAISYS.llaisysQwen2ModelCreate.argtypes = [POINTER(Qwen2Meta), c_int, POINTER(c_int), c_int] +LIB_LLAISYS.llaisysQwen2ModelCreate.restype = c_void_p + +LIB_LLAISYS.llaisysQwen2ModelDestroy.argtypes = [c_void_p] +LIB_LLAISYS.llaisysQwen2ModelDestroy.restype = None + +LIB_LLAISYS.llaisysQwen2ModelWeights.argtypes = [c_void_p] +LIB_LLAISYS.llaisysQwen2ModelWeights.restype = POINTER(Qwen2Weights) + +LIB_LLAISYS.llaisysQwen2ModelInfer.argtypes = [c_void_p, POINTER(c_int64), c_size_t] +LIB_LLAISYS.llaisysQwen2ModelInfer.restype = c_int64 diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b21..c2d44887e 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -1,23 +1,85 @@ from typing import Sequence from ..libllaisys import 
LIB_LLAISYS -from ..libllaisys import DeviceType - +from ..libllaisys import DeviceType, DataType +from ..libllaisys.models import Qwen2Meta, Qwen2Weights from pathlib import Path import safetensors - +import json +from ctypes import c_int64, c_int class Qwen2: def __init__(self, model_path, device: DeviceType = DeviceType.CPU): # TODO: Implement model constructor - model_path = Path(model_path) - + + with open(model_path / "config.json") as f: + config = json.load(f) + + meta = Qwen2Meta() + meta.dtype = DataType.F32 + meta.nlayer = config["num_hidden_layers"] + meta.hs = config["hidden_size"] + meta.nh = config["num_attention_heads"] + meta.nkvh = config["num_key_value_heads"] + meta.dh = config["hidden_size"] // config["num_attention_heads"] + meta.di = config["intermediate_size"] + meta.maxseq = config.get("max_position_embeddings", 32768) + meta.voc = config["vocab_size"] + meta.epsilon = config["rms_norm_eps"] + meta.theta = config.get("rope_theta", 10000.0) + meta.end_token = config.get("eos_token_id", 151643) + + self.model = LIB_LLAISYS.llaisysQwen2ModelCreate(meta, device, None, 0) + self.weights_ptr = LIB_LLAISYS.llaisysQwen2ModelWeights(self.model) + self.weights = self.weights_ptr.contents + self.nlayer = meta.nlayer + self.end_token = meta.end_token + for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") + data_ = safetensors.safe_open(file, framework="pt", device="cpu") for name_ in data_.keys(): ## TODO: load the model weights - pass + tensor_data = data_.get_tensor(name_) + if tensor_data.dtype.is_floating_point and tensor_data.dtype != tensor_data.float().dtype: + tensor_data = tensor_data.float() + tensor_data = tensor_data.numpy() + self._load_weight(name_, tensor_data) + + def _load_weight(self, name, data): + if name == "model.embed_tokens.weight": + LIB_LLAISYS.tensorLoad(self.weights.in_embed, data.ctypes.data) + elif name == "lm_head.weight": + 
LIB_LLAISYS.tensorLoad(self.weights.out_embed, data.ctypes.data) + elif name == "model.norm.weight": + LIB_LLAISYS.tensorLoad(self.weights.out_norm_w, data.ctypes.data) + elif "layers" in name: + parts = name.split(".") + layer_idx = int(parts[2]) + if "input_layernorm.weight" in name: + LIB_LLAISYS.tensorLoad(self.weights.attn_norm_w[layer_idx], data.ctypes.data) + elif "self_attn.q_proj.weight" in name: + LIB_LLAISYS.tensorLoad(self.weights.attn_q_w[layer_idx], data.ctypes.data) + elif "self_attn.q_proj.bias" in name: + LIB_LLAISYS.tensorLoad(self.weights.attn_q_b[layer_idx], data.ctypes.data) + elif "self_attn.k_proj.weight" in name: + LIB_LLAISYS.tensorLoad(self.weights.attn_k_w[layer_idx], data.ctypes.data) + elif "self_attn.k_proj.bias" in name: + LIB_LLAISYS.tensorLoad(self.weights.attn_k_b[layer_idx], data.ctypes.data) + elif "self_attn.v_proj.weight" in name: + LIB_LLAISYS.tensorLoad(self.weights.attn_v_w[layer_idx], data.ctypes.data) + elif "self_attn.v_proj.bias" in name: + LIB_LLAISYS.tensorLoad(self.weights.attn_v_b[layer_idx], data.ctypes.data) + elif "self_attn.o_proj.weight" in name: + LIB_LLAISYS.tensorLoad(self.weights.attn_o_w[layer_idx], data.ctypes.data) + elif "post_attention_layernorm.weight" in name: + LIB_LLAISYS.tensorLoad(self.weights.mlp_norm_w[layer_idx], data.ctypes.data) + elif "mlp.gate_proj.weight" in name: + LIB_LLAISYS.tensorLoad(self.weights.mlp_gate_w[layer_idx], data.ctypes.data) + elif "mlp.up_proj.weight" in name: + LIB_LLAISYS.tensorLoad(self.weights.mlp_up_w[layer_idx], data.ctypes.data) + elif "mlp.down_proj.weight" in name: + LIB_LLAISYS.tensorLoad(self.weights.mlp_down_w[layer_idx], data.ctypes.data) def generate( self, @@ -27,7 +89,16 @@ def generate( top_p: float = 0.8, temperature: float = 0.8, ): - # TODO: Implement generate function - - return [] + tokens = list(inputs) + for _ in range(max_new_tokens or 128): + token_array = (c_int64 * len(tokens))(*tokens) + next_token = 
LIB_LLAISYS.llaisysQwen2ModelInfer(self.model, token_array, len(tokens)) + tokens.append(next_token) + if next_token == self.end_token: + break + return tokens + + def __del__(self): + if hasattr(self, 'model'): + LIB_LLAISYS.llaisysQwen2ModelDestroy(self.model) diff --git a/src/llaisys/qwen2.cc b/src/llaisys/qwen2.cc new file mode 100644 index 000000000..7effc4235 --- /dev/null +++ b/src/llaisys/qwen2.cc @@ -0,0 +1,31 @@ +#include "llaisys/models/qwen2.h" +#include "../models/qwen2/qwen2_model.hpp" + +__C { + +struct LlaisysQwen2Model { + llaisys::models::Qwen2Model *model; +}; + +__export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice) { + auto model = new LlaisysQwen2Model; + model->model = new llaisys::models::Qwen2Model(meta, device, device_ids ? device_ids[0] : 0); + return model; +} + +__export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model *model) { + delete model->model; + delete model; +} + +__export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model *model) { + auto weights = new LlaisysQwen2Weights; + *weights = model->model->getWeights(); + return weights; +} + +__export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken) { + return model->model->infer(token_ids, ntoken); +} + +} diff --git a/src/models/qwen2/qwen2_model.cpp b/src/models/qwen2/qwen2_model.cpp new file mode 100644 index 000000000..90c7f2fdc --- /dev/null +++ b/src/models/qwen2/qwen2_model.cpp @@ -0,0 +1,184 @@ +#include "qwen2_model.hpp" +#include "../../llaisys/llaisys_tensor.hpp" +#include "../../ops/embedding/op.hpp" +#include "../../ops/linear/op.hpp" +#include "../../ops/rms_norm/op.hpp" +#include "../../ops/rope/op.hpp" +#include "../../ops/self_attention/op.hpp" +#include "../../ops/swiglu/op.hpp" +#include "../../ops/argmax/op.hpp" +#include "../../ops/add/op.hpp" +#include "../../ops/rearrange/op.hpp" 
+#include "../../utils.hpp" +#include +#include + +namespace llaisys::models { + +Qwen2Model::Qwen2Model(const LlaisysQwen2Meta *meta_, llaisysDeviceType_t device, int dev_id) + : device_type(device), device_id(dev_id), cur_seq_len(0) { + meta = *meta_; + + in_embed = Tensor::create({meta.voc, meta.hs}, meta.dtype, device, dev_id); + out_embed = Tensor::create({meta.voc, meta.hs}, meta.dtype, device, dev_id); + out_norm_w = Tensor::create({meta.hs}, meta.dtype, device, dev_id); + + for (size_t i = 0; i < meta.nlayer; i++) { + attn_norm_w.push_back(Tensor::create({meta.hs}, meta.dtype, device, dev_id)); + attn_q_w.push_back(Tensor::create({meta.nh * meta.dh, meta.hs}, meta.dtype, device, dev_id)); + attn_q_b.push_back(Tensor::create({meta.nh * meta.dh}, meta.dtype, device, dev_id)); + attn_k_w.push_back(Tensor::create({meta.nkvh * meta.dh, meta.hs}, meta.dtype, device, dev_id)); + attn_k_b.push_back(Tensor::create({meta.nkvh * meta.dh}, meta.dtype, device, dev_id)); + attn_v_w.push_back(Tensor::create({meta.nkvh * meta.dh, meta.hs}, meta.dtype, device, dev_id)); + attn_v_b.push_back(Tensor::create({meta.nkvh * meta.dh}, meta.dtype, device, dev_id)); + attn_o_w.push_back(Tensor::create({meta.hs, meta.nh * meta.dh}, meta.dtype, device, dev_id)); + + mlp_norm_w.push_back(Tensor::create({meta.hs}, meta.dtype, device, dev_id)); + mlp_gate_w.push_back(Tensor::create({meta.di, meta.hs}, meta.dtype, device, dev_id)); + mlp_up_w.push_back(Tensor::create({meta.di, meta.hs}, meta.dtype, device, dev_id)); + mlp_down_w.push_back(Tensor::create({meta.hs, meta.di}, meta.dtype, device, dev_id)); + + k_cache.push_back(Tensor::create({meta.maxseq, meta.nkvh, meta.dh}, meta.dtype, device, dev_id)); + v_cache.push_back(Tensor::create({meta.maxseq, meta.nkvh, meta.dh}, meta.dtype, device, dev_id)); + } +} + +LlaisysQwen2Weights Qwen2Model::getWeights() { + LlaisysQwen2Weights weights; + weights.in_embed = new LlaisysTensor{in_embed}; + weights.out_embed = new LlaisysTensor{out_embed}; + 
weights.out_norm_w = new LlaisysTensor{out_norm_w}; + + weights.attn_norm_w = new llaisysTensor_t[meta.nlayer]; + weights.attn_q_w = new llaisysTensor_t[meta.nlayer]; + weights.attn_q_b = new llaisysTensor_t[meta.nlayer]; + weights.attn_k_w = new llaisysTensor_t[meta.nlayer]; + weights.attn_k_b = new llaisysTensor_t[meta.nlayer]; + weights.attn_v_w = new llaisysTensor_t[meta.nlayer]; + weights.attn_v_b = new llaisysTensor_t[meta.nlayer]; + weights.attn_o_w = new llaisysTensor_t[meta.nlayer]; + weights.mlp_norm_w = new llaisysTensor_t[meta.nlayer]; + weights.mlp_gate_w = new llaisysTensor_t[meta.nlayer]; + weights.mlp_up_w = new llaisysTensor_t[meta.nlayer]; + weights.mlp_down_w = new llaisysTensor_t[meta.nlayer]; + + for (size_t i = 0; i < meta.nlayer; i++) { + weights.attn_norm_w[i] = new LlaisysTensor{attn_norm_w[i]}; + weights.attn_q_w[i] = new LlaisysTensor{attn_q_w[i]}; + weights.attn_q_b[i] = new LlaisysTensor{attn_q_b[i]}; + weights.attn_k_w[i] = new LlaisysTensor{attn_k_w[i]}; + weights.attn_k_b[i] = new LlaisysTensor{attn_k_b[i]}; + weights.attn_v_w[i] = new LlaisysTensor{attn_v_w[i]}; + weights.attn_v_b[i] = new LlaisysTensor{attn_v_b[i]}; + weights.attn_o_w[i] = new LlaisysTensor{attn_o_w[i]}; + weights.mlp_norm_w[i] = new LlaisysTensor{mlp_norm_w[i]}; + weights.mlp_gate_w[i] = new LlaisysTensor{mlp_gate_w[i]}; + weights.mlp_up_w[i] = new LlaisysTensor{mlp_up_w[i]}; + weights.mlp_down_w[i] = new LlaisysTensor{mlp_down_w[i]}; + } + + return weights; +} + +int64_t Qwen2Model::infer(int64_t *token_ids, size_t ntoken) { + //向前传播 + size_t seqlen = ntoken - cur_seq_len; + + //embedding + auto idx_tensor = Tensor::create({seqlen}, LLAISYS_DTYPE_I64, device_type, device_id); + idx_tensor->load(token_ids + cur_seq_len); + auto x = Tensor::create({seqlen, meta.hs}, meta.dtype, device_type, device_id); + ops::embedding(x, idx_tensor, in_embed); + + //位置编码ID + std::vector pos_ids_vec(seqlen); + for (size_t i = 0; i < seqlen; i++) pos_ids_vec[i] = cur_seq_len + i; + 
auto pos_ids = Tensor::create({seqlen}, LLAISYS_DTYPE_I64, device_type, device_id); + pos_ids->load(pos_ids_vec.data()); + + //Transformer层 + for (size_t layer = 0; layer < meta.nlayer; layer++) { + + auto x_norm = Tensor::create({seqlen, meta.hs}, meta.dtype, device_type, device_id); + ops::rms_norm(x_norm, x, attn_norm_w[layer], meta.epsilon); + + auto q = Tensor::create({seqlen, meta.nh * meta.dh}, meta.dtype, device_type, device_id); + auto k = Tensor::create({seqlen, meta.nkvh * meta.dh}, meta.dtype, device_type, device_id); + auto v = Tensor::create({seqlen, meta.nkvh * meta.dh}, meta.dtype, device_type, device_id); + ops::linear(q, x_norm, attn_q_w[layer], attn_q_b[layer]); + ops::linear(k, x_norm, attn_k_w[layer], attn_k_b[layer]); + ops::linear(v, x_norm, attn_v_w[layer], attn_v_b[layer]); + + //重塑 + q = q->view({seqlen, meta.nh, meta.dh}); + k = k->view({seqlen, meta.nkvh, meta.dh}); + v = v->view({seqlen, meta.nkvh, meta.dh}); + + //rope + auto q_rope = Tensor::create({seqlen, meta.nh, meta.dh}, meta.dtype, device_type, device_id); + auto k_rope = Tensor::create({seqlen, meta.nkvh, meta.dh}, meta.dtype, device_type, device_id); + ops::rope(q_rope, q, pos_ids, meta.theta); + ops::rope(k_rope, k, pos_ids, meta.theta); + + //更新KV cache + auto k_cache_slice = k_cache[layer]->slice(0, cur_seq_len, cur_seq_len + seqlen); + auto v_cache_slice = v_cache[layer]->slice(0, cur_seq_len, cur_seq_len + seqlen); + ops::rearrange(k_cache_slice, k_rope); + ops::rearrange(v_cache_slice, v); + + auto k_full = k_cache[layer]->slice(0, 0, cur_seq_len + seqlen); + auto v_full = v_cache[layer]->slice(0, 0, cur_seq_len + seqlen); + + //self attention + auto attn_out = Tensor::create({seqlen, meta.nh, meta.dh}, meta.dtype, device_type, device_id); + float scale = 1.0f / std::sqrt(static_cast(meta.dh)); + ops::self_attention(attn_out, q_rope, k_full, v_full, scale); + + attn_out = attn_out->view({seqlen, meta.nh * meta.dh}); + auto attn_proj = Tensor::create({seqlen, meta.hs}, 
meta.dtype, device_type, device_id); + ops::linear(attn_proj, attn_out, attn_o_w[layer], nullptr); + + ops::add(x, x, attn_proj); + + //MLP + auto x_mlp = Tensor::create({seqlen, meta.hs}, meta.dtype, device_type, device_id); + ops::rms_norm(x_mlp, x, mlp_norm_w[layer], meta.epsilon); + + auto gate = Tensor::create({seqlen, meta.di}, meta.dtype, device_type, device_id); + auto up = Tensor::create({seqlen, meta.di}, meta.dtype, device_type, device_id); + ops::linear(gate, x_mlp, mlp_gate_w[layer], nullptr); + ops::linear(up, x_mlp, mlp_up_w[layer], nullptr); + + auto mlp_out = Tensor::create({seqlen, meta.di}, meta.dtype, device_type, device_id); + ops::swiglu(mlp_out, gate, up); + + auto mlp_proj = Tensor::create({seqlen, meta.hs}, meta.dtype, device_type, device_id); + ops::linear(mlp_proj, mlp_out, mlp_down_w[layer], nullptr); + + // residual + ops::add(x, x, mlp_proj); + } + + //归一化 + auto x_final = Tensor::create({seqlen, meta.hs}, meta.dtype, device_type, device_id); + ops::rms_norm(x_final, x, out_norm_w, meta.epsilon); + + //用最后一个预测 + auto last_hidden = x_final->slice(0, seqlen - 1, seqlen); + auto logits = Tensor::create({1, meta.voc}, meta.dtype, device_type, device_id); + ops::linear(logits, last_hidden, out_embed, nullptr); + + //argmax + auto max_idx = Tensor::create({1}, LLAISYS_DTYPE_I64, device_type, device_id); + auto max_val = Tensor::create({1}, meta.dtype, device_type, device_id); + ops::argmax(max_idx, max_val, logits->view({meta.voc})); + + int64_t result; + std::byte *data = max_idx->data(); + std::memcpy(&result, data, sizeof(int64_t)); + + cur_seq_len += seqlen; + return result; +} + +} diff --git a/src/models/qwen2/qwen2_model.hpp b/src/models/qwen2/qwen2_model.hpp new file mode 100644 index 000000000..a82450e55 --- /dev/null +++ b/src/models/qwen2/qwen2_model.hpp @@ -0,0 +1,26 @@ +#pragma once +#include "../../tensor/tensor.hpp" +#include "llaisys/models/qwen2.h" +#include + +namespace llaisys::models { + +class Qwen2Model { +private: + 
LlaisysQwen2Meta meta; + llaisysDeviceType_t device_type; + int device_id; + + tensor_t in_embed, out_embed, out_norm_w; + std::vector attn_norm_w, attn_q_w, attn_q_b, attn_k_w, attn_k_b, attn_v_w, attn_v_b, attn_o_w; + std::vector mlp_norm_w, mlp_gate_w, mlp_up_w, mlp_down_w; + std::vector k_cache, v_cache; + size_t cur_seq_len; + +public: + Qwen2Model(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int device_id); + LlaisysQwen2Weights getWeights(); + int64_t infer(int64_t *token_ids, size_t ntoken); +}; + +} diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp new file mode 100644 index 000000000..a3911e5cb --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.cpp @@ -0,0 +1,31 @@ +#include "argmax_cpu.hpp" +#include "../../../utils.hpp" +#include + +template +void argmax_(int64_t *idx, T *val, const T *vals, size_t n) { + float max = -std::numeric_limits::infinity(); + size_t pos = 0; + for (size_t i = 0; i < n; i++) { + float v = llaisys::utils::cast(vals[i]); + if (v > max) { max = v; pos = i; } + } + idx[0] = pos; + val[0] = llaisys::utils::cast(max); +} + +namespace llaisys::ops::cpu { +void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t n) { + auto idx = reinterpret_cast(max_idx); + switch (type) { + case LLAISYS_DTYPE_F32: + return argmax_(idx, reinterpret_cast(max_val), reinterpret_cast(vals), n); + case LLAISYS_DTYPE_BF16: + return argmax_(idx, reinterpret_cast(max_val), reinterpret_cast(vals), n); + case LLAISYS_DTYPE_F16: + return argmax_(idx, reinterpret_cast(max_val), reinterpret_cast(vals), n); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp new file mode 100644 index 000000000..32ea738d4 --- /dev/null +++ b/src/ops/argmax/cpu/argmax_cpu.hpp @@ -0,0 +1,7 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { +void argmax(std::byte 
*max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t size); +} diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp index 6dc37d426..9bd964500 100644 --- a/src/ops/argmax/op.cpp +++ b/src/ops/argmax/op.cpp @@ -1,7 +1,28 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/argmax_cpu.hpp" namespace llaisys::ops { void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(max_idx, max_val, vals); + + if (vals->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); + } + + llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId()); + + switch (vals->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp new file mode 100644 index 000000000..b7455c3e6 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.cpp @@ -0,0 +1,30 @@ +#include "embedding_cpu.hpp" +#include "../../../utils.hpp" + +template +void embedding_(T *out, const int64_t *index, const T *weight, size_t idx_size, size_t embd_dim) { + for (size_t i = 0; i < idx_size; i++) { + int64_t idx = index[i]; + const T *src = weight + idx * embd_dim; + T *dst = out + i * embd_dim; + for (size_t j = 0; j < embd_dim; j++) { + dst[j] = src[j]; + } + } +} + +namespace llaisys::ops::cpu { +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, size_t idx_size, size_t embd_dim) { + auto idx = reinterpret_cast(index); + switch (type) { + case LLAISYS_DTYPE_F32: + return 
embedding_(reinterpret_cast(out), idx, reinterpret_cast(weight), idx_size, embd_dim); + case LLAISYS_DTYPE_BF16: + return embedding_(reinterpret_cast(out), idx, reinterpret_cast(weight), idx_size, embd_dim); + case LLAISYS_DTYPE_F16: + return embedding_(reinterpret_cast(out), idx, reinterpret_cast(weight), idx_size, embd_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp new file mode 100644 index 000000000..669f3c121 --- /dev/null +++ b/src/ops/embedding/cpu/embedding_cpu.hpp @@ -0,0 +1,7 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { +void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, size_t idx_size, size_t embd_dim); +} diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp index 84b9a5d06..c4cbd1f36 100644 --- a/src/ops/embedding/op.cpp +++ b/src/ops/embedding/op.cpp @@ -1,7 +1,28 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/embedding_cpu.hpp" namespace llaisys::ops { void embedding(tensor_t out, tensor_t index, tensor_t weight) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, index, weight); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index->numel(), weight->shape()[1]); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index->numel(), weight->shape()[1]); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp new file mode 100644 index 
000000000..b2c883034 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.cpp @@ -0,0 +1,44 @@ +#include "linear_cpu.hpp" +#include "../../../utils.hpp" +#include + +template +void linear_(T *out, const T *in, const T *weight, const T *bias, size_t batch, size_t in_dim, size_t out_dim) { + #pragma omp parallel for + for (int b = 0; b < static_cast(batch); b++) { + for (size_t o = 0; o < out_dim; o++) { + float sum = 0.0f; + for (size_t i = 0; i < in_dim; i++) { + float x = llaisys::utils::cast(in[b * in_dim + i]); + float w = llaisys::utils::cast(weight[o * in_dim + i]); + sum += x * w; + } + if (bias) { + sum += llaisys::utils::cast(bias[o]); + } + out[b * out_dim + o] = llaisys::utils::cast(sum); + } + } +} + +namespace llaisys::ops::cpu { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + llaisysDataType_t type, size_t batch, size_t in_dim, size_t out_dim) { + switch (type) { + case LLAISYS_DTYPE_F32: + return linear_(reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(weight), reinterpret_cast(bias), + batch, in_dim, out_dim); + case LLAISYS_DTYPE_BF16: + return linear_(reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(weight), reinterpret_cast(bias), + batch, in_dim, out_dim); + case LLAISYS_DTYPE_F16: + return linear_(reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(weight), reinterpret_cast(bias), + batch, in_dim, out_dim); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp new file mode 100644 index 000000000..407eb3842 --- /dev/null +++ b/src/ops/linear/cpu/linear_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { +void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias, + llaisysDataType_t type, size_t batch, size_t in_dim, size_t out_dim); +} diff --git a/src/ops/linear/op.cpp 
b/src/ops/linear/op.cpp index 97d1f8655..9206b612e 100644 --- a/src/ops/linear/op.cpp +++ b/src/ops/linear/op.cpp @@ -1,7 +1,38 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/linear_cpu.hpp" namespace llaisys::ops { void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) { - TO_BE_IMPLEMENTED(); + if (bias) { + CHECK_SAME_DEVICE(out, in, weight, bias); + } else { + CHECK_SAME_DEVICE(out, in, weight); + } + + size_t batch = in->shape()[0]; + size_t in_dim = in->shape()[1]; + size_t out_dim = weight->shape()[0]; + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, + out->dtype(), batch, in_dim, out_dim); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::linear(out->data(), in->data(), weight->data(), bias ? bias->data() : nullptr, + out->dtype(), batch, in_dim, out_dim); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rearrange/cpu/rearrange_cpu.cpp b/src/ops/rearrange/cpu/rearrange_cpu.cpp new file mode 100644 index 000000000..f7f333df7 --- /dev/null +++ b/src/ops/rearrange/cpu/rearrange_cpu.cpp @@ -0,0 +1,27 @@ +#include "rearrange_cpu.hpp" +#include + +namespace llaisys::ops::cpu { + +void rearrange_recursive(std::byte *out, const std::byte *in, const std::vector &shape, + const std::vector &in_strides, size_t elem_size, + size_t dim, size_t &out_offset, size_t in_offset) { + if (dim == shape.size()) { + std::memcpy(out + out_offset, in + in_offset, elem_size); + out_offset += elem_size; + return; + } + + for (size_t i = 0; i < shape[dim]; i++) { + rearrange_recursive(out, in, shape, in_strides, elem_size, dim + 1, out_offset, + in_offset + i * 
in_strides[dim] * elem_size); + } +} + +void rearrange(std::byte *out, const std::byte *in, const std::vector &shape, + const std::vector &in_strides, size_t elem_size) { + size_t out_offset = 0; + rearrange_recursive(out, in, shape, in_strides, elem_size, 0, out_offset, 0); +} + +} diff --git a/src/ops/rearrange/cpu/rearrange_cpu.hpp b/src/ops/rearrange/cpu/rearrange_cpu.hpp new file mode 100644 index 000000000..97f5d14e4 --- /dev/null +++ b/src/ops/rearrange/cpu/rearrange_cpu.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "llaisys.h" +#include +#include + +namespace llaisys::ops::cpu { +void rearrange(std::byte *out, const std::byte *in, const std::vector &shape, + const std::vector &in_strides, size_t elem_size); +} diff --git a/src/ops/rearrange/op.cpp b/src/ops/rearrange/op.cpp index 017a6ae59..ccb194c9e 100644 --- a/src/ops/rearrange/op.cpp +++ b/src/ops/rearrange/op.cpp @@ -1,7 +1,29 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/rearrange_cpu.hpp" namespace llaisys::ops { void rearrange(tensor_t out, tensor_t in) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in); + CHECK_SAME_SHAPE(out->shape(), in->shape()); + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rearrange(out->data(), in->data(), in->shape(), in->strides(), in->elementSize()); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rearrange(out->data(), in->data(), in->shape(), in->strides(), in->elementSize()); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.cpp b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp new file mode 100644 index 000000000..bc7c28429 --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp @@ -0,0 +1,39 @@ +#include "rms_norm_cpu.hpp" 
+#include "../../../utils.hpp" +#include + +template +void rms_norm_(T *out, const T *in, const T *weight, size_t batch, size_t dim, float eps) { + for (size_t b = 0; b < batch; b++) { + float sum_sq = 0.0f; + for (size_t i = 0; i < dim; i++) { + float val = llaisys::utils::cast(in[b * dim + i]); + sum_sq += val * val; + } + float rms = std::sqrt(sum_sq / dim + eps); + for (size_t i = 0; i < dim; i++) { + float val = llaisys::utils::cast(in[b * dim + i]); + float w = llaisys::utils::cast(weight[i]); + out[b * dim + i] = llaisys::utils::cast(w * val / rms); + } + } +} + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, + size_t batch, size_t dim, float eps) { + switch (type) { + case LLAISYS_DTYPE_F32: + return rms_norm_(reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(weight), batch, dim, eps); + case LLAISYS_DTYPE_BF16: + return rms_norm_(reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(weight), batch, dim, eps); + case LLAISYS_DTYPE_F16: + return rms_norm_(reinterpret_cast(out), reinterpret_cast(in), + reinterpret_cast(weight), batch, dim, eps); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.hpp b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp new file mode 100644 index 000000000..321e05d5f --- /dev/null +++ b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { +void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, + size_t batch, size_t dim, float eps); +} diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp index 529553d9d..c3beb364c 100644 --- a/src/ops/rms_norm/op.cpp +++ b/src/ops/rms_norm/op.cpp @@ -1,7 +1,31 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/rms_norm_cpu.hpp" namespace llaisys::ops { void 
rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, weight); + + size_t batch = in->shape()[0]; + size_t dim = in->shape()[1]; + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rms_norm(out->data(), in->data(), weight->data(), out->dtype(), batch, dim, eps); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rms_norm(out->data(), in->data(), weight->data(), out->dtype(), batch, dim, eps); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/rope/cpu/rope_cpu.cpp b/src/ops/rope/cpu/rope_cpu.cpp new file mode 100644 index 000000000..c142aadcf --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.cpp @@ -0,0 +1,42 @@ +#include "rope_cpu.hpp" +#include "../../../utils.hpp" +#include + +template +void rope_(T *out, const T *in, const int64_t *pos_ids, size_t seqlen, size_t nhead, size_t d, float theta) { + size_t half_d = d / 2; + for (size_t s = 0; s < seqlen; s++) { + float pos = static_cast(pos_ids[s]); + for (size_t h = 0; h < nhead; h++) { + for (size_t j = 0; j < half_d; j++) { + float angle = pos / std::pow(theta, 2.0f * j / d); + float cos_val = std::cos(angle); + float sin_val = std::sin(angle); + + size_t idx = s * nhead * d + h * d; + float a = llaisys::utils::cast(in[idx + j]); + float b = llaisys::utils::cast(in[idx + j + half_d]); + + out[idx + j] = llaisys::utils::cast(a * cos_val - b * sin_val); + out[idx + j + half_d] = llaisys::utils::cast(b * cos_val + a * sin_val); + } + } + } +} + +namespace llaisys::ops::cpu { +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, + size_t seqlen, size_t nhead, size_t d, float theta) { + auto ids = reinterpret_cast(pos_ids); + switch (type) { + case 
LLAISYS_DTYPE_F32: + return rope_(reinterpret_cast(out), reinterpret_cast(in), ids, seqlen, nhead, d, theta); + case LLAISYS_DTYPE_BF16: + return rope_(reinterpret_cast(out), reinterpret_cast(in), ids, seqlen, nhead, d, theta); + case LLAISYS_DTYPE_F16: + return rope_(reinterpret_cast(out), reinterpret_cast(in), ids, seqlen, nhead, d, theta); + default: + EXCEPTION_UNSUPPORTED_DATATYPE(type); + } +} +} diff --git a/src/ops/rope/cpu/rope_cpu.hpp b/src/ops/rope/cpu/rope_cpu.hpp new file mode 100644 index 000000000..0ff668bd1 --- /dev/null +++ b/src/ops/rope/cpu/rope_cpu.hpp @@ -0,0 +1,8 @@ +#pragma once +#include "llaisys.h" +#include + +namespace llaisys::ops::cpu { +void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, + size_t seqlen, size_t nhead, size_t d, float theta); +} diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64e..cd6f22dce 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,32 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/rope_cpu.hpp" namespace llaisys::ops { void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(out, in, pos_ids); + + size_t seqlen = in->shape()[0]; + size_t nhead = in->shape()[1]; + size_t d = in->shape()[2]; + + if (out->deviceType() == LLAISYS_DEVICE_CPU) { + return cpu::rope(out->data(), in->data(), pos_ids->data(), out->dtype(), seqlen, nhead, d, theta); + } + + llaisys::core::context().setDevice(out->deviceType(), out->deviceId()); + + switch (out->deviceType()) { + case LLAISYS_DEVICE_CPU: + return cpu::rope(out->data(), in->data(), pos_ids->data(), out->dtype(), seqlen, nhead, d, theta); +#ifdef ENABLE_NVIDIA_API + case LLAISYS_DEVICE_NVIDIA: + TO_BE_IMPLEMENTED(); + return; +#endif + default: + EXCEPTION_UNSUPPORTED_DEVICE; + } } } // namespace llaisys::ops diff --git a/src/ops/self_attention/cpu/self_attention_cpu.cpp 
b/src/ops/self_attention/cpu/self_attention_cpu.cpp
new file mode 100644
index 000000000..ff4024b0b
--- /dev/null
+++ b/src/ops/self_attention/cpu/self_attention_cpu.cpp
@@ -0,0 +1,83 @@
+#include "self_attention_cpu.hpp"
+#include "../../../utils.hpp"
+#include <cmath>
+#include <limits>
+#include <vector>
+
+template <typename T>
+void self_attention_(T *attn_val, const T *q, const T *k, const T *v,
+                     size_t qlen, size_t kvlen, size_t nhead, size_t nkvhead, size_t d, float scale) {
+    size_t head_repeat = nhead / nkvhead;
+
+    for (size_t qi = 0; qi < qlen; qi++) {
+        for (size_t h = 0; h < nhead; h++) {
+            size_t kv_h = h / head_repeat;
+            std::vector<float> scores(kvlen);
+
+            // Compute Q * K^T * scale
+            for (size_t ki = 0; ki < kvlen; ki++) {
+                float sum = 0.0f;
+                for (size_t j = 0; j < d; j++) {
+                    float q_val = llaisys::utils::cast<float>(q[qi * nhead * d + h * d + j]);
+                    float k_val = llaisys::utils::cast<float>(k[ki * nkvhead * d + kv_h * d + j]);
+                    sum += q_val * k_val;
+                }
+                scores[ki] = sum * scale;
+            }
+
+            // Apply causal mask and softmax
+            float max_score = -std::numeric_limits<float>::infinity();
+            for (size_t ki = 0; ki < kvlen; ki++) {
+                if (ki <= kvlen - qlen + qi) { // causal mask
+                    max_score = std::max(max_score, scores[ki]);
+                }
+            }
+
+            float sum_exp = 0.0f;
+            for (size_t ki = 0; ki < kvlen; ki++) {
+                if (ki <= kvlen - qlen + qi) {
+                    scores[ki] = std::exp(scores[ki] - max_score);
+                    sum_exp += scores[ki];
+                } else {
+                    scores[ki] = 0.0f;
+                }
+            }
+
+            for (size_t ki = 0; ki < kvlen; ki++) {
+                scores[ki] /= sum_exp;
+            }
+
+            // Compute attention * V
+            for (size_t j = 0; j < d; j++) {
+                float sum = 0.0f;
+                for (size_t ki = 0; ki < kvlen; ki++) {
+                    float v_val = llaisys::utils::cast<float>(v[ki * nkvhead * d + kv_h * d + j]);
+                    sum += scores[ki] * v_val;
+                }
+                attn_val[qi * nhead * d + h * d + j] = llaisys::utils::cast<T>(sum);
+            }
+        }
+    }
+}
+
+namespace llaisys::ops::cpu {
+void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v,
+                    llaisysDataType_t type, size_t qlen, size_t kvlen, size_t
nhead, size_t nkvhead, size_t d, float scale) {
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        return self_attention_(reinterpret_cast<float *>(attn_val), reinterpret_cast<const float *>(q),
+                               reinterpret_cast<const float *>(k), reinterpret_cast<const float *>(v),
+                               qlen, kvlen, nhead, nkvhead, d, scale);
+    case LLAISYS_DTYPE_BF16:
+        return self_attention_(reinterpret_cast<llaisys::bf16_t *>(attn_val), reinterpret_cast<const llaisys::bf16_t *>(q),
+                               reinterpret_cast<const llaisys::bf16_t *>(k), reinterpret_cast<const llaisys::bf16_t *>(v),
+                               qlen, kvlen, nhead, nkvhead, d, scale);
+    case LLAISYS_DTYPE_F16:
+        return self_attention_(reinterpret_cast<llaisys::fp16_t *>(attn_val), reinterpret_cast<const llaisys::fp16_t *>(q),
+                               reinterpret_cast<const llaisys::fp16_t *>(k), reinterpret_cast<const llaisys::fp16_t *>(v),
+                               qlen, kvlen, nhead, nkvhead, d, scale);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(type);
+    }
+}
+}
diff --git a/src/ops/self_attention/cpu/self_attention_cpu.hpp b/src/ops/self_attention/cpu/self_attention_cpu.hpp
new file mode 100644
index 000000000..3a2784823
--- /dev/null
+++ b/src/ops/self_attention/cpu/self_attention_cpu.hpp
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v,
+                    llaisysDataType_t type, size_t qlen, size_t kvlen, size_t nhead, size_t nkvhead, size_t d, float scale);
+}
diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp
index 43d620142..ca9408eb7 100644
--- a/src/ops/self_attention/op.cpp
+++ b/src/ops/self_attention/op.cpp
@@ -1,7 +1,36 @@
 #include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+#include "cpu/self_attention_cpu.hpp"
 
 namespace llaisys::ops {
 void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) {
-    TO_BE_IMPLEMENTED();
+    CHECK_SAME_DEVICE(attn_val, q, k, v);
+
+    size_t qlen = q->shape()[0];
+    size_t nhead = q->shape()[1];
+    size_t d = q->shape()[2];
+    size_t kvlen = k->shape()[0];
+    size_t nkvhead = k->shape()[1];
+
+    if (attn_val->deviceType() == LLAISYS_DEVICE_CPU) {
+        return cpu::self_attention(attn_val->data(), q->data(),
k->data(), v->data(),
+                                   attn_val->dtype(), qlen, kvlen, nhead, nkvhead, d, scale);
+    }
+
+    llaisys::core::context().setDevice(attn_val->deviceType(), attn_val->deviceId());
+
+    switch (attn_val->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(),
+                                   attn_val->dtype(), qlen, kvlen, nhead, nkvhead, d, scale);
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        TO_BE_IMPLEMENTED();
+        return;
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/swiglu/cpu/swiglu_cpu.cpp b/src/ops/swiglu/cpu/swiglu_cpu.cpp
new file mode 100644
index 000000000..fa5eac331
--- /dev/null
+++ b/src/ops/swiglu/cpu/swiglu_cpu.cpp
@@ -0,0 +1,31 @@
+#include "swiglu_cpu.hpp"
+#include "../../../utils.hpp"
+#include <cmath>
+
+template <typename T>
+void swiglu_(T *out, const T *gate, const T *up, size_t size) {
+    for (size_t i = 0; i < size; i++) {
+        float g = llaisys::utils::cast<float>(gate[i]);
+        float u = llaisys::utils::cast<float>(up[i]);
+        float silu = g / (1.0f + std::exp(-g));
+        out[i] = llaisys::utils::cast<T>(u * silu);
+    }
+}
+
+namespace llaisys::ops::cpu {
+void swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t type, size_t size) {
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        return swiglu_(reinterpret_cast<float *>(out), reinterpret_cast<const float *>(gate),
+                       reinterpret_cast<const float *>(up), size);
+    case LLAISYS_DTYPE_BF16:
+        return swiglu_(reinterpret_cast<llaisys::bf16_t *>(out), reinterpret_cast<const llaisys::bf16_t *>(gate),
+                       reinterpret_cast<const llaisys::bf16_t *>(up), size);
+    case LLAISYS_DTYPE_F16:
+        return swiglu_(reinterpret_cast<llaisys::fp16_t *>(out), reinterpret_cast<const llaisys::fp16_t *>(gate),
+                       reinterpret_cast<const llaisys::fp16_t *>(up), size);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(type);
+    }
+}
+}
diff --git a/src/ops/swiglu/cpu/swiglu_cpu.hpp b/src/ops/swiglu/cpu/swiglu_cpu.hpp
new file mode 100644
index 000000000..b6369aee5
--- /dev/null
+++ b/src/ops/swiglu/cpu/swiglu_cpu.hpp
@@ -0,0 +1,7 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void
swiglu(std::byte *out, const std::byte *gate, const std::byte *up, llaisysDataType_t type, size_t size);
+}
diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp
index 47edbcc97..9cf378831 100644
--- a/src/ops/swiglu/op.cpp
+++ b/src/ops/swiglu/op.cpp
@@ -1,7 +1,28 @@
 #include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+#include "cpu/swiglu_cpu.hpp"
 
 namespace llaisys::ops {
 void swiglu(tensor_t out, tensor_t gate, tensor_t up) {
-    TO_BE_IMPLEMENTED();
+    CHECK_SAME_DEVICE(out, gate, up);
+
+    if (out->deviceType() == LLAISYS_DEVICE_CPU) {
+        return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), out->numel());
+    }
+
+    llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+
+    switch (out->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), out->numel());
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        TO_BE_IMPLEMENTED();
+        return;
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp
index 2f594bb65..70aa971ff 100644
--- a/src/tensor/tensor.cpp
+++ b/src/tensor/tensor.cpp
@@ -164,27 +164,103 @@ void Tensor::debug() const {
 }
 
 bool Tensor::isContiguous() const {
-    TO_BE_IMPLEMENTED();
+    const auto &shape_ref = this->shape();
+    const auto &stride_ref = this->strides();
+
+    if (shape_ref.empty()) {
+        return true;
+    }
+
+    size_t expected_stride = 1;
+    for (size_t dim = shape_ref.size(); dim-- > 0;) {
+        if (shape_ref[dim] == 0) {
+            return true;
+        }
+        if (stride_ref[dim] != static_cast<ptrdiff_t>(expected_stride)) {
+            return false;
+        }
+        expected_stride *= shape_ref[dim];
+    }
     return true;
 }
 
 tensor_t Tensor::permute(const std::vector<size_t> &order) const {
-    TO_BE_IMPLEMENTED();
-    return std::shared_ptr<Tensor>(new Tensor(_meta, _storage));
+    const size_t ndim_ = this->ndim();
+    CHECK_ARGUMENT(order.size() == ndim_, "permute order size mismatch");
+
+    std::vector<bool>
seen(ndim_, false);
+    std::vector<size_t> new_shape(ndim_);
+    std::vector<ptrdiff_t> new_strides(ndim_);
+
+    for (size_t i = 0; i < ndim_; ++i) {
+        CHECK_ARGUMENT(order[i] < ndim_, "permute index out of range");
+        CHECK_ARGUMENT(!seen[order[i]], "permute order must be a permutation");
+        seen[order[i]] = true;
+
+        new_shape[i] = this->shape()[order[i]];
+        new_strides[i] = this->strides()[order[i]];
+    }
+
+    TensorMeta meta{this->dtype(), std::move(new_shape), std::move(new_strides)};
+    return std::shared_ptr<Tensor>(new Tensor(meta, _storage, _offset));
 }
 
 tensor_t Tensor::view(const std::vector<size_t> &shape) const {
-    TO_BE_IMPLEMENTED();
-    return std::shared_ptr<Tensor>(new Tensor(_meta, _storage));
+    const size_t new_elems =
+        std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
+    CHECK_ARGUMENT(new_elems == this->numel(), "view shape mismatch with tensor size");
+
+    CHECK_ARGUMENT(this->isContiguous(), "view requires contiguous tensor");
+
+    std::vector<ptrdiff_t> new_strides(shape.size());
+    size_t stride = 1;
+    for (size_t i = shape.size(); i-- > 0;) {
+        new_strides[i] = static_cast<ptrdiff_t>(stride);
+        stride *= shape[i];
+    }
+
+    TensorMeta meta{this->dtype(), shape, new_strides};
+    return std::shared_ptr<Tensor>(new Tensor(meta, _storage));
 }
 
 tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const {
-    TO_BE_IMPLEMENTED();
-    return std::shared_ptr<Tensor>(new Tensor(_meta, _storage));
+    CHECK_ARGUMENT(dim < this->ndim(), "slice dim out of range");
+    CHECK_ARGUMENT(start <= end, "slice start must be <= end");
+    CHECK_ARGUMENT(end <= this->shape()[dim], "slice end out of range");
+
+    auto new_shape = this->shape();
+    new_shape[dim] = end - start;
+
+    auto new_strides = this->strides();
+    CHECK_ARGUMENT(new_strides[dim] >= 0, "slice requires non-negative stride along dim");
+
+    const size_t elem_offset = static_cast<size_t>(new_strides[dim]) * start;
+    const size_t byte_offset = _offset + elem_offset * this->elementSize();
+
+    TensorMeta meta{this->dtype(), std::move(new_shape), std::move(new_strides)};
+    return
std::shared_ptr<Tensor>(new Tensor(meta, _storage, byte_offset));
 }
 
 void Tensor::load(const void *src_) {
-    TO_BE_IMPLEMENTED();
+    if (!src_) {
+        throw std::invalid_argument("Tensor::load received null src pointer");
+    }
+    const size_t bytes = this->numel() * this->elementSize();
+    if (bytes == 0) {
+        return;
+    }
+    auto dst = this->data();
+    auto device = this->deviceType();
+    if (device == LLAISYS_DEVICE_CPU) {
+        std::memcpy(dst, src_, bytes);
+    } else {
+        core::context().setDevice(device, this->deviceId());
+        core::context().runtime().api()->memcpy_sync(
+            dst,
+            src_,
+            bytes,
+            LLAISYS_MEMCPY_H2D);
+    }
 }
 
 tensor_t Tensor::contiguous() const {
diff --git a/xmake.lua b/xmake.lua
index 1f65f7a95..82eec90f7 100644
--- a/xmake.lua
+++ b/xmake.lua
@@ -3,6 +3,9 @@ set_encodings("utf-8")
 
 add_includedirs("include")
 
+-- add the OpenMP package
+add_requires("openmp")
+
 -- CPU --
 includes("xmake/cpu.lua")
 
@@ -95,6 +98,22 @@ target("llaisys-ops")
     on_install(function (target) end)
 target_end()
 
+target("llaisys-models")
+    set_kind("static")
+    add_deps("llaisys-tensor")
+    add_deps("llaisys-ops")
+
+    set_languages("cxx17")
+    set_warnings("all", "error")
+    if not is_plat("windows") then
+        add_cxflags("-fPIC", "-Wno-unknown-pragmas")
+    end
+
+    add_files("src/models/*/*.cpp")
+
+    on_install(function (target) end)
+target_end()
+
 target("llaisys")
     set_kind("shared")
     add_deps("llaisys-utils")
@@ -102,6 +121,7 @@ target("llaisys")
     add_deps("llaisys-core")
     add_deps("llaisys-tensor")
     add_deps("llaisys-ops")
+    add_deps("llaisys-models")
 
     set_languages("cxx17")
     set_warnings("all", "error")
@@ -119,4 +139,4 @@ target("llaisys")
             os.cp("lib/*.so", "python/llaisys/libllaisys/")
         end
     end)
-target_end()
\ No newline at end of file
+target_end()
diff --git a/xmake/cpu.lua b/xmake/cpu.lua
index 101d894e6..38ec55436 100644
--- a/xmake/cpu.lua
+++ b/xmake/cpu.lua
@@ -16,6 +16,10 @@ target("llaisys-ops-cpu")
     add_deps("llaisys-tensor")
     set_languages("cxx17")
     set_warnings("all", "error")
+
+    -- add OpenMP support
+
add_packages("openmp")
+
     if not is_plat("windows") then
         add_cxflags("-fPIC", "-Wno-unknown-pragmas")
     end
@@ -24,4 +28,3 @@ target("llaisys-ops-cpu")
 
     on_install(function (target) end)
 target_end()
-