diff --git a/.gitignore b/.gitignore index e38cf5747..e243776c4 100644 --- a/.gitignore +++ b/.gitignore @@ -87,4 +87,7 @@ htmlcov/ # Windows Thumbs.db ehthumbs.db -desktop.ini \ No newline at end of file +desktop.ini +# Windows Zone.Identifier (invalid path on Windows if committed) +*:Zone.Identifier +*.Identifier \ No newline at end of file diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h index 7054626d4..444f447a7 100644 --- a/include/llaisys/models/qwen2.h +++ b/include/llaisys/models/qwen2.h @@ -31,12 +31,26 @@ __C { struct LlaisysQwen2Model; - __export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice); + __export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice,llaisysDataType_t dtype); __export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model); __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model); - __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken); + __export void llaisysQwen2LoadWeight( + struct LlaisysQwen2Model * model, + const char * name, + void * data, + size_t * shape, + size_t ndim, + llaisysDataType_t dtype); + + __export void *llaisysQwen2ModelForward( + struct LlaisysQwen2Model * model, + int64_t * token_ids, + size_t seq_len, + size_t start_pos); + + __export int llaisysQwen2Sample(void * logits_ptr); } #endif // LLAISYS_MODELS_QWEN2_H diff --git a/linear.prof b/linear.prof new file mode 100644 index 000000000..18c73a47e Binary files /dev/null and b/linear.prof differ diff --git a/python/llaisys/libllaisys/__init__.py b/python/llaisys/libllaisys/__init__.py index f536fb527..e5c229c78 100644 --- a/python/llaisys/libllaisys/__init__.py +++ b/python/llaisys/libllaisys/__init__.py @@ -9,6 +9,7 @@ from .llaisys_types import llaisysDataType_t, DataType from .llaisys_types import llaisysMemcpyKind_t, MemcpyKind from .llaisys_types import llaisysStream_t +from .llaisys_types import LlaisysQwen2Meta from .tensor import llaisysTensor_t from .tensor import load_tensor from .ops import load_ops @@ -33,11 +34,51 @@ def load_shared_library(): return ctypes.CDLL(str(lib_path)) +def load_qwen2_api(lib): + try: + if hasattr(lib,'llaisysQwen2ModelCreate'): + lib.llaisysQwen2ModelCreate.argtypes=[ + ctypes.POINTER(LlaisysQwen2Meta), + ctypes.c_int, + ctypes.POINTER(ctypes.c_int), + ctypes.c_int, + ctypes.c_int + ] + lib.llaisysQwen2ModelCreate.restype=ctypes.c_void_p + if hasattr(lib,'llaisysQwen2LoadWeight'): + lib.llaisysQwen2LoadWeight.argtypes=[ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_size_t), + ctypes.c_size_t, + ctypes.c_int + ] + lib.llaisysQwen2LoadWeight.restype=None + except Exception as e: + print(f"Warning: Failed to load Qwen2 API signatures. {e}") +def llaisys_qwen2_create(meta,device_id): + return LIB_LLAISYS.llaisysQwen2ModelCreate( + ctypes.byref(meta), + device_id, + None, + 0 + ) +def llaisys_qwen2_load_weight(model_handle,name,data_ptr,shape,ndim,dtype): + LIB_LLAISYS.llaisysQwen2LoadWeight( + model_handle, + name, + data_ptr, + shape, + ndim, + dtype + ) LIB_LLAISYS = load_shared_library() load_runtime(LIB_LLAISYS) load_tensor(LIB_LLAISYS) load_ops(LIB_LLAISYS) +load_qwen2_api(LIB_LLAISYS) __all__ = [ @@ -52,4 +93,7 @@ def load_shared_library(): "llaisysMemcpyKind_t", "MemcpyKind", "llaisysStream_t", + "LlaisysQwen2Meta", + "llaisys_qwen2_create", + "llaisys_qwen2_load_weight" ] diff --git a/python/llaisys/libllaisys/llaisys_types.py b/python/llaisys/libllaisys/llaisys_types.py index c5a0b4679..a7b20be13 100644 --- a/python/llaisys/libllaisys/llaisys_types.py +++ b/python/llaisys/libllaisys/llaisys_types.py @@ -61,3 +61,20 @@ class MemcpyKind(IntEnum): "MemcpyKind", "llaisysStream_t", ] + + +class LlaisysQwen2Meta(ctypes.Structure): + _fields_ = [ + ("dtype",ctypes.c_int), + ("nlayer",ctypes.c_size_t), + ("hs",ctypes.c_size_t), + ("nh",ctypes.c_size_t), + ("nkvh",ctypes.c_size_t), + ("dh",ctypes.c_size_t), + ("di",ctypes.c_size_t), + ("maxseq",ctypes.c_size_t), + ("voc",ctypes.c_size_t), + ("epsilon",ctypes.c_float), + ("theta",ctypes.c_float), + ("end_token",ctypes.c_int64,) + ] diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b21..001af9069 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -4,21 +4,141 @@ from pathlib import Path import safetensors +import ctypes +import numpy as np +import struct +import json +import mmap +import os - +from ..libllaisys import( + DeviceType, + LlaisysQwen2Meta, + llaisys_qwen2_create, + llaisys_qwen2_load_weight +) +TYPE_MAP={ + # 必须与 C++ enum llaisysDataType_t 完全一致 + # 参见 include/llaisys.h 和 python/llaisys/libllaisys/llaisys_types.py + "F32":13, + "F16":12, # 修复: 之前错误地映射为 11(F8) + "BF16":19 + } class Qwen2: def __init__(self, model_path, device: DeviceType = DeviceType.CPU): - # TODO: Implement model constructor + self.lib=LIB_LLAISYS - model_path = Path(model_path) + config_path=os.path.join(model_path,"config.json") + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config not found at {config_path}") + with open(config_path,"r") as f: + config=json.load(f) + + meta=LlaisysQwen2Meta() + meta.nlayer=config.get("num_hidden_layers",28) + meta.hs=config.get("hidden_size",1536) + meta.nh=config.get("num_attention_heads",12) + meta.nkvh=config.get("num_key_value_heads",2) + meta.vocab_size=config.get("vocab_size",151936) + meta.maxseq=config.get("max_position_embeddings",32768) + meta.epsilon=config.get("rms_norm_eps",1e-6) + meta.theta=config.get("rope_theta",10000.0) + + config_dtype_str=config.get("torch_dtype","float16") + target_key="F16" + if config_dtype_str=="float32": + target_key="F32" + elif config_dtype_str=="bfloat16": + target_key="BF16" + elif config_dtype_str=="float16": + target_key="F16" + + if target_key not in TYPE_MAP: + print(f"Warning: Unknown dtype {config_dtype_str}, using F16") + target_dtype=TYPE_MAP["F16"] + else: + target_dtype=TYPE_MAP[target_key] + self.model=self.lib.llaisysQwen2ModelCreate( + ctypes.byref(meta), + device.value, + None, + 0, + target_dtype + ) + + self.lib.llaisysQwen2ModelForward.restype=ctypes.c_void_p + self.lib.llaisysQwen2ModelForward.argtypes=[ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_int64), + ctypes.c_size_t, + ctypes.c_size_t + ] + if hasattr(self.lib,'llaisysQwen2Sample'): + self.lib.llaisysQwen2Sample.restype=ctypes.c_int + self.lib.llaisysQwen2Sample.argtypes=[ctypes.c_void_p] + + model_path = Path(model_path) for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") - for name_ in data_.keys(): - ## TODO: load the model weights - pass + with open(file,'rb') as f_obj: + header_size=struct.unpack(' +#include +#include +#include +#include +using namespace llaisys; + +struct Qwen2Layer{ + tensor_t attention_norm; + tensor_t w_q; + tensor_t w_k; + tensor_t w_v; + tensor_t w_o; + + tensor_t b_q; + tensor_t b_k; + tensor_t b_v; + tensor_t b_o; + + tensor_t ffn_norm; + tensor_t w_gate; + tensor_t w_up; + tensor_t w_down; +}; +struct LlaisysQwen2Model{ + LlaisysQwen2Meta meta; + llaisysDeviceType_t device; + + tensor_t tok_embeddings; + tensor_t norm; + tensor_t output; + + std::vector layers; + + std::vector k_cache; + std::vector v_cache; +}; + +extern "C"{ + +LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice,llaisysDataType_t dtype){ + auto model=new LlaisysQwen2Model(); + model->meta=*meta; + model->device=device; + model->layers.resize(meta->nlayer); + model->k_cache.resize(meta->nlayer); + model->v_cache.resize(meta->nlayer); + + size_t head_dim=meta->hs/meta->nh; + std::vector cache_shape={1,meta->maxseq,meta->nkvh,head_dim}; + + for(size_t i=0;inlayer;++i){ + model->k_cache[i]=Tensor::create(cache_shape, dtype,device,0); + model->v_cache[i]=Tensor::create(cache_shape, dtype,device,0); + } + return model; +} +void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model){ + if(model){ + delete model; + } +} +void llaisysQwen2LoadWeight( + LlaisysQwen2Model*model, + const char*name, + void*data, + size_t*shape, + size_t ndim, + llaisysDataType_t dtype +){ + std::string w_name(name); + tensor_t* target=nullptr; + + if(w_name=="model.embed_tokens.weight"){ + target=&model->tok_embeddings; + } + else if(w_name=="model.norm.weight"){ + target=&model->norm; + } + else if(w_name=="lm_head.weight"){ + target=&model->output; + } + + else if(w_name.find("model.layers.")==0){ + size_t first_dot=13; + size_t second_dot=w_name.find(".",first_dot); + std::string layer_id_str=w_name.substr(first_dot,second_dot-first_dot); + size_t layer_id=std::stoull(layer_id_str); + + if(layer_id>=0&&layer_idmeta.nlayer){ + auto&layer=model->layers[layer_id]; + std::string suffix=w_name.substr(second_dot+1); + + if(suffix=="self_attn.q_proj.weight") target=&layer.w_q; + else if(suffix=="self_attn.k_proj.weight") target=&layer.w_k; + else if(suffix=="self_attn.v_proj.weight") target=&layer.w_v; + else if(suffix=="self_attn.o_proj.weight") target=&layer.w_o; + else if(suffix=="self_attn.q_proj.bias") target=&layer.b_q; + else if(suffix=="self_attn.k_proj.bias") target=&layer.b_k; + else if(suffix=="self_attn.v_proj.bias") target=&layer.b_v; + else if(suffix=="self_attn.o_proj.bias") target=&layer.b_o; + else if(suffix=="mlp.gate_proj.weight") target=&layer.w_gate; + else if(suffix=="mlp.up_proj.weight") target=&layer.w_up; + else if(suffix=="mlp.down_proj.weight") target=&layer.w_down; + else if(suffix=="input_layernorm.weight") target=&layer.attention_norm; + else if(suffix=="post_attention_layernorm.weight") target=&layer.ffn_norm; + } + } + if(target){ + std::vector shape_vec(shape,shape+ndim); + *target=Tensor::create(shape_vec,dtype,model->device,0); + (*target)->load(data); + } + +} +void* llaisysQwen2ModelForward( + LlaisysQwen2Model* model, + int64_t*input_ids_ptr, + size_t seq_len, + size_t start_pos +){ + auto device=model->device; + auto dtype=model->tok_embeddings->dtype(); + size_t hs=model->meta.hs; + size_t head_dim=hs/model->meta.nh; + size_t kv_dim=head_dim*model->meta.nkvh; + std::vector q_shape={1,seq_len,hs}; + std::vector kv_shape={1,seq_len,kv_dim}; + + std::vector input_shape={1,seq_len}; + + tensor_t input_tensor=Tensor::create(input_shape,LLAISYS_DTYPE_I64,device,0); + if (!input_ids_ptr) return nullptr; + input_tensor->load(input_ids_ptr); + + std::vector hidden_shape={1,seq_len,hs}; + + tensor_t hidden_states=Tensor::create(hidden_shape,dtype,device,0); + + ops::embedding(hidden_states, input_tensor, model->tok_embeddings); + + std::vector pos_shape={1,seq_len}; + tensor_t pos_ids=Tensor::create(pos_shape, LLAISYS_DTYPE_I64,device,0); + std::vector pos_vec(seq_len); + for(size_t i=0;iload(pos_vec.data()); + + for(size_t i=0;imeta.nlayer;++i){ + auto&layer=model->layers[i]; + + tensor_t norm_out=Tensor::create(hidden_shape,dtype,device,0); + ops::rms_norm(norm_out, hidden_states, layer.attention_norm,model->meta.epsilon); + + tensor_t q=Tensor::create(q_shape,dtype,device,0); + tensor_t k=Tensor::create(kv_shape,dtype,device,0); + tensor_t v=Tensor::create(kv_shape,dtype,device,0); + + ops::linear(q, norm_out, layer.w_q, layer.b_q); + ops::linear(k, norm_out, layer.w_k, layer.b_k); + ops::linear(v, norm_out, layer.w_v, layer.b_v); + + // 使用配置中的 RoPE 基数 theta,而不是硬编码 1e6, + // 以与 HuggingFace/Qwen2 的实现保持一致,避免位置编码频率不匹配。 + ops::rope(q, q, pos_ids, model->meta.theta); + ops::rope(k, k, pos_ids, model->meta.theta); + + tensor_t k_slot=model->k_cache[i]->slice(1,start_pos,start_pos+seq_len); + tensor_t v_slot=model->v_cache[i]->slice(1,start_pos,start_pos+seq_len); + + k_slot->load(k->data()); + v_slot->load(v->data()); + + tensor_t full_k=model->k_cache[i]->slice(1,0,start_pos+seq_len); + tensor_t full_v=model->v_cache[i]->slice(1,0,start_pos+seq_len); + + tensor_t attn_out=Tensor::create(hidden_shape,dtype,device,0); + float scale = 0.0883883f; + ops::self_attention(attn_out, q, full_k, full_v, scale); + + tensor_t proj_out=Tensor::create(hidden_shape,dtype,device,0); + ops::linear(proj_out, attn_out, layer.w_o, layer.b_o); + + ops::add(hidden_states, hidden_states, proj_out); + + tensor_t ffn_norm_out=Tensor::create(hidden_shape,dtype,device,0); + ops::rms_norm(ffn_norm_out, hidden_states,layer.ffn_norm,model->meta.epsilon); + + size_t inter_size=layer.w_gate->shape()[0]; + std::vector inter_shape={1,seq_len,inter_size}; + + tensor_t gate=Tensor::create(inter_shape,dtype,device,0); + tensor_t up=Tensor::create(inter_shape,dtype,device,0); + + ops::linear(gate, ffn_norm_out, layer.w_gate, nullptr); + ops::linear(up, ffn_norm_out,layer.w_up,nullptr); + + tensor_t act=Tensor::create(inter_shape,dtype,device,0); + ops::swiglu(act, gate, up); + + tensor_t mlp_out=Tensor::create(hidden_shape, dtype,device,0); + ops::linear(mlp_out, act, layer.w_down, nullptr); + + ops::add(hidden_states, hidden_states, mlp_out); + } + + tensor_t final_norm=Tensor::create(hidden_shape,dtype,device,0); + ops::rms_norm(final_norm, hidden_states, model->norm, model->meta.epsilon); + + size_t vocab_size=model->output->shape()[0]; + std::vector logits_shape={1,seq_len,vocab_size}; + tensor_t logits=Tensor::create(logits_shape,final_norm->dtype(),device,0); + + ops::linear(logits, final_norm, model->output, nullptr); + + tensor_t* heap_logits=new tensor_t(logits); + return (void*)heap_logits; + +} +int llaisysQwen2Sample(void* logits_void_ptr) { + if (!logits_void_ptr) return 0; + + tensor_t* ptr_to_shared = (tensor_t*)logits_void_ptr; + tensor_t logits = *ptr_to_shared; + size_t seq_len = logits->shape()[1]; + tensor_t last_token_logits=logits->slice(1,seq_len-1,seq_len); + tensor_t final_logits=last_token_logits->contiguous(); + + std::vector out_shape = {1}; + tensor_t max_idx = Tensor::create(out_shape, LLAISYS_DTYPE_I64, logits->deviceType(), logits->deviceId()); + tensor_t max_val = Tensor::create(out_shape, logits->dtype(), logits->deviceType(), logits->deviceId()); + ops::argmax(max_idx, max_val, final_logits); + int64_t result_index = *reinterpret_cast(max_idx->data()); + + delete ptr_to_shared; + return (int)result_index; +} +} \ No newline at end of file diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp index 6dc37d426..f32adc9ce 100644 --- a/src/ops/argmax/op.cpp +++ b/src/ops/argmax/op.cpp @@ -1,7 +1,47 @@ #include "op.hpp" - +#include "../../utils.hpp" namespace llaisys::ops { + template + void argmax_cpu_kernel(tensor_t max_idx,tensor_t max_val,const tensor_t vals){ + const T*src=reinterpret_cast(vals->data()); + T*dst_val=reinterpret_cast(max_val->data()); + int64_t*dst_idx=reinterpret_cast(max_idx->data()); + + // 获取最后一维的大小(vocab_size) + size_t last_dim = vals->shape().back(); + size_t num_rows = vals->numel() / last_dim; // 其他维度的乘积 + + // 对每一行分别求 argmax + for(size_t row = 0; row < num_rows; row++){ + float cur_max_fval = utils::cast(src[row * last_dim]); + T cur_max_val = src[row * last_dim]; + size_t cur_max_idx = 0; + + for(size_t i = 1; i < last_dim; i++){ + float cur_fval = utils::cast(src[row * last_dim + i]); + if(cur_fval > cur_max_fval){ + cur_max_fval = cur_fval; + cur_max_idx = i; + cur_max_val = src[row * last_dim + i]; + } + } + dst_val[row] = cur_max_val; + dst_idx[row] = static_cast(cur_max_idx); + } + } void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { - TO_BE_IMPLEMENTED(); + switch (vals->dtype()) { + case LLAISYS_DTYPE_F16: + argmax_cpu_kernel(max_idx, max_val, vals); + break; + case LLAISYS_DTYPE_BF16: + argmax_cpu_kernel(max_idx, max_val, vals); + break; + case LLAISYS_DTYPE_F32: + argmax_cpu_kernel(max_idx, max_val, vals); + break; + default: + throw std::runtime_error("Not support this dtype!"); + } } } // namespace llaisys::ops diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp index 84b9a5d06..2561f7bd8 100644 --- a/src/ops/embedding/op.cpp +++ b/src/ops/embedding/op.cpp @@ -1,7 +1,33 @@ #include "op.hpp" - +#include namespace llaisys::ops { +template +void embedding_cpu_kernel(tensor_t out, tensor_t index, tensor_t weight){ + const T*weight_val=reinterpret_cast(weight->data()); + T*out_val=reinterpret_cast(out->data()); + const int64_t*index_val=reinterpret_cast(index->data()); + size_t embedding_dim=weight->shape().back(); + size_t n=index->numel(); + for(size_t i=0;idtype()) { + case LLAISYS_DTYPE_F16: + embedding_cpu_kernel(out,index,weight); + break; + case LLAISYS_DTYPE_BF16: + embedding_cpu_kernel(out,index,weight); + break; + case LLAISYS_DTYPE_F32: + embedding_cpu_kernel(out,index,weight); + break; + default: + throw std::runtime_error("Not support this dtype!"); + } } } // namespace llaisys::ops diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp index 97d1f8655..a311be45d 100644 --- a/src/ops/linear/op.cpp +++ b/src/ops/linear/op.cpp @@ -1,7 +1,833 @@ +// simd.hpp MUST be included before op.hpp because op.hpp transitively pulls +// in llaisys.h which defines `#define __C extern "C"`. That macro corrupts +// parameter names inside (e.g. ia32intrin.h:63 `__crc32b(...__C...)`). +// By including simd.hpp first, is parsed before the __C macro exists. +#include "../../utils/simd.hpp" +#include "../../utils/blas_runtime.hpp" #include "op.hpp" +#include "../../utils.hpp" +#include +#include +#include +#include +#include +#include +#ifdef _OPENMP +#include +#endif +#ifdef USE_OPENBLAS +#include +#endif namespace llaisys::ops { + +// 权重缓存: bf16/fp16 权重在首次使用时转为 f32 并缓存, 后续调用直接复用。 +// key = (原始权重数据指针, 元素个数, 类型标签)。 +// 类型标签区分同一地址不同 dtype 的数据 (测试场景下内存可能被复用)。 +// 推理场景中权重地址唯一且不释放, 缓存命中率 100%。 +// 同时供编译期 OpenBLAS 路径和运行时 BLAS (MKL/OpenBLAS dlopen) 路径使用。 +struct WeightCacheKey { + const void* ptr; + size_t count; + int dtype_tag; // 0=bf16, 1=fp16 + bool operator==(const WeightCacheKey& o) const { + return ptr == o.ptr && count == o.count && dtype_tag == o.dtype_tag; + } +}; +struct WeightCacheKeyHash { + size_t operator()(const WeightCacheKey& k) const { + size_t h = std::hash()(k.ptr); + h ^= std::hash()(k.count) << 1; + h ^= std::hash()(k.dtype_tag) << 2; + return h; + } +}; +static std::unordered_map, WeightCacheKeyHash> g_weight_cache; +static std::mutex g_weight_cache_mutex; + +// 查找或创建 f32 权重缓存 (线程安全) +template +static const float* get_cached_f32_weights(const void* key, size_t count, + int dtype_tag, ConvertFn convert_fn) +{ + WeightCacheKey cache_key{key, count, dtype_tag}; + std::lock_guard lock(g_weight_cache_mutex); + auto it = g_weight_cache.find(cache_key); + if (it != g_weight_cache.end()) + return it->second.data(); + auto& buf = g_weight_cache[cache_key]; + buf.resize(count); + convert_fn(buf.data()); + return buf.data(); +} + +// bf16 → f32 批量转换 (AVX2 SIMD 加速) +static void bf16_to_f32_bulk(const uint16_t* src, float* dst, size_t count) { + size_t i = 0; + for (; i + 7 < count; i += 8) { + __m256 v = utils::bf16x8_to_f32x8(src + i); + _mm256_storeu_ps(dst + i, v); + } + for (; i < count; ++i) + dst[i] = utils::cast(llaisys::bf16_t{src[i]}); +} + +// fp16 → f32 批量转换 (AVX2 SIMD 加速) +static void fp16_to_f32_bulk(const uint16_t* src, float* dst, size_t count) { + size_t i = 0; + for (; i + 7 < count; i += 8) { + __m256 v = utils::fp16x8_to_f32x8(src + i); + _mm256_storeu_ps(dst + i, v); + } + for (; i < count; ++i) + dst[i] = utils::cast(llaisys::fp16_t{src[i]}); +} + +// 判断是否有可用的 BLAS (编译期或运行时) +static inline bool has_blas() { +#ifdef USE_OPENBLAS + return true; +#else + return blas::available(); +#endif +} + +// 统一 sgemm 调用接口: 编译期 OpenBLAS 或运行时 BLAS +static inline void call_sgemm(int M, int N, int K, + float alpha, const float* A, int lda, + const float* B, int ldb, + float beta, float* C, int ldc) +{ +#ifdef USE_OPENBLAS + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + (blasint)M, (blasint)N, (blasint)K, + alpha, A, lda, B, ldb, beta, C, ldc); +#else + blas::sgemm(blas::CblasRowMajor, blas::CblasNoTrans, blas::CblasTrans, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); +#endif +} + +// M 维度低于此阈值时跳过 OpenMP,避免线程创建/同步开销。 +// 模型推理解码阶段 M=1,若不跳过,每次 linear 调用产生数千次无效 barrier。 +static constexpr size_t OMP_M_THRESHOLD = 32; + +// 条件并行 for:当 n >= threshold 时使用 OpenMP 多线程,否则纯串行。 +// 使用 C 级 if 而非 OpenMP if() 子句,后者仍会进入 GOMP 运行时产生开销。 +template +static inline void parallel_for(size_t n, size_t threshold, F&& func) { +#ifdef _OPENMP + if (n >= threshold) { + #pragma omp parallel for schedule(static) + for (size_t i = 0; i < n; ++i) func(i); + return; + } +#endif + for (size_t i = 0; i < n; ++i) func(i); +} + +// ============================================================ +// 泛型版本 (fallback): 纯标量, 无 OpenMP, 无 SIMD +// ============================================================ +template +void linear_cpu_kernel(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias){ + const T* in_ptr = reinterpret_cast(in->data()); + const T* weight_ptr = reinterpret_cast(weight->data()); + const T* bias_ptr = nullptr; + if(bias && bias->numel() > 0) bias_ptr = reinterpret_cast(bias->data()); + T* out_ptr = reinterpret_cast(out->data()); + + size_t K = in->shape().back(); + size_t N = weight->shape()[0]; + + size_t M; + if (in->shape().size() == 2) { + M = in->shape()[0]; + } else { + M = in->shape()[0] * in->shape()[1]; + } + + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + float sum = 0.0f; + for (size_t k = 0; k < K; ++k) { + float x_val = utils::cast(in_ptr[k + i * K]); + float w_val = utils::cast(weight_ptr[k + j * K]); + sum += x_val * w_val; + } + if (bias_ptr) { + sum += utils::cast(bias_ptr[j]); + } + out_ptr[j + i * N] = utils::cast(sum); + } + } +} + +// 提取维度信息: M, K, N +static inline void get_dims(tensor_t in, tensor_t weight, + size_t &M, size_t &K, size_t &N) +{ + const auto& xs = in->shape(); + const auto& ws = weight->shape(); + if (xs.size() == 2) { + M = xs[0]; K = xs[1]; + } else if (xs.size() == 3) { + M = xs[0] * xs[1]; K = xs[2]; + } else { + K = xs.back(); M = in->numel() / K; + } + N = ws[0]; +} + +// ============================================================ +// float32 特化 +// 有 BLAS (编译期 OpenBLAS 或运行时 MKL/OpenBLAS) 时调用 sgemm; +// 否则使用手写 AVX2+FMA + OpenMP 并行化版本。 +// ============================================================ +template<> +void linear_cpu_kernel(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias){ + const float* X = reinterpret_cast(in->data()); + const float* W = reinterpret_cast(weight->data()); + const float* b = (bias && bias->numel() > 0) + ? reinterpret_cast(bias->data()) : nullptr; + float* Y = reinterpret_cast(out->data()); + + size_t M, K, N; + get_dims(in, weight, M, K, N); + + // BLAS 路径 (编译期或运行时) + if (has_blas()) { + // Y = X[M,K] * W[N,K]^T => sgemm(NoTrans, Trans, M, N, K, 1, X, K, W, K, 0, Y, N) + // 若有 bias,先将 bias 广播填入 Y,然后用 beta=1 累加 GEMM 结果。 + if (b) { + for (size_t i = 0; i < M; ++i) + std::memcpy(Y + i * N, b, N * sizeof(float)); + call_sgemm((int)M, (int)N, (int)K, + 1.0f, X, (int)K, W, (int)K, 1.0f, Y, (int)N); + } else { + call_sgemm((int)M, (int)N, (int)K, + 1.0f, X, (int)K, W, (int)K, 0.0f, Y, (int)N); + } + return; + } + + // ---------- 手写 AVX2+FMA fallback ---------- + if (b) { + parallel_for(M, OMP_M_THRESHOLD, [&](size_t i) { + float* yrow = Y + i * N; + size_t j = 0; + for (; j + 7 < N; j += 8) + _mm256_storeu_ps(yrow + j, _mm256_loadu_ps(b + j)); + for (; j < N; ++j) + yrow[j] = b[j]; + }); + } else { + std::memset(Y, 0, M * N * sizeof(float)); + } + + static constexpr size_t BLOCK_K = 512; + static constexpr size_t BLOCK_N = 4; + + for (size_t k0 = 0; k0 < K; k0 += BLOCK_K) { + const size_t kc = std::min(BLOCK_K, K - k0); + + for (size_t j0 = 0; j0 < N; j0 += BLOCK_N) { + const size_t jn = std::min(BLOCK_N, N - j0); + + if (jn == 4) { + const float* w0 = W + (j0 + 0) * K + k0; + const float* w1 = W + (j0 + 1) * K + k0; + const float* w2 = W + (j0 + 2) * K + k0; + const float* w3 = W + (j0 + 3) * K + k0; + + parallel_for(M, OMP_M_THRESHOLD, [&](size_t i) { + const float* xi = X + i * K + k0; + __m256 a00 = _mm256_setzero_ps(), a01 = _mm256_setzero_ps(); + __m256 a10 = _mm256_setzero_ps(), a11 = _mm256_setzero_ps(); + __m256 a20 = _mm256_setzero_ps(), a21 = _mm256_setzero_ps(); + __m256 a30 = _mm256_setzero_ps(), a31 = _mm256_setzero_ps(); + + size_t k = 0; + for (; k + 15 < kc; k += 16) { + __m256 x0 = _mm256_loadu_ps(xi + k); + __m256 x1 = _mm256_loadu_ps(xi + k + 8); + a00 = _mm256_fmadd_ps(x0, _mm256_loadu_ps(w0 + k), a00); + a01 = _mm256_fmadd_ps(x1, _mm256_loadu_ps(w0 + k + 8), a01); + a10 = _mm256_fmadd_ps(x0, _mm256_loadu_ps(w1 + k), a10); + a11 = _mm256_fmadd_ps(x1, _mm256_loadu_ps(w1 + k + 8), a11); + a20 = _mm256_fmadd_ps(x0, _mm256_loadu_ps(w2 + k), a20); + a21 = _mm256_fmadd_ps(x1, _mm256_loadu_ps(w2 + k + 8), a21); + a30 = _mm256_fmadd_ps(x0, _mm256_loadu_ps(w3 + k), a30); + a31 = _mm256_fmadd_ps(x1, _mm256_loadu_ps(w3 + k + 8), a31); + } + for (; k + 7 < kc; k += 8) { + __m256 x0 = _mm256_loadu_ps(xi + k); + a00 = _mm256_fmadd_ps(x0, _mm256_loadu_ps(w0 + k), a00); + a10 = _mm256_fmadd_ps(x0, _mm256_loadu_ps(w1 + k), a10); + a20 = _mm256_fmadd_ps(x0, _mm256_loadu_ps(w2 + k), a20); + a30 = _mm256_fmadd_ps(x0, _mm256_loadu_ps(w3 + k), a30); + } + + float d0 = utils::hsum256(_mm256_add_ps(a00, a01)); + float d1 = utils::hsum256(_mm256_add_ps(a10, a11)); + float d2 = utils::hsum256(_mm256_add_ps(a20, a21)); + float d3 = utils::hsum256(_mm256_add_ps(a30, a31)); + + for (; k < kc; ++k) { + float xk = xi[k]; + d0 += xk * w0[k]; d1 += xk * w1[k]; + d2 += xk * w2[k]; d3 += xk * w3[k]; + } + + Y[i*N + j0] += d0; + Y[i*N + j0+1] += d1; + Y[i*N + j0+2] += d2; + Y[i*N + j0+3] += d3; + }); + } else { + for (size_t jj = 0; jj < jn; ++jj) { + const float* wj = W + (j0 + jj) * K + k0; + parallel_for(M, OMP_M_THRESHOLD, [&](size_t i) { + Y[i*N + j0 + jj] += utils::avx2_dot(X + i*K + k0, wj, kc); + }); + } + } + } + } +} + +// ============================================================ +// bfloat16 特化 +// USE_OPENBLAS 时: M >= 阈值 → bf16→f32 + cblas_sgemm (权重缓存); +// M < 阈值 → 手写 AVX2 (内存带宽受限, 无需转换)。 +// 否则始终使用手写 AVX2 + OpenMP 并行化版本。 +// ============================================================ + +// M 维度 >= 此阈值时使用 BLAS sgemm, 否则使用手写 SIMD 内核。 +// 解码阶段 M=1 是内存带宽瓶颈, bf16 直接读取比转 f32 后读取少一半带宽, +// 因此小 M 优先使用手写 bf16 SIMD 内核 (AVX2 或 AVX-512, 按编译目标自动选择)。 +static constexpr size_t BLAS_M_THRESHOLD = 32; + +// 手写 AVX2 bf16 linear 实现 +// M >= OMP_M_THRESHOLD 时并行化 M 维度, M < 阈值时并行化 N 维度。 +// 解码阶段 M=1, N 维度 (1536~5632) 提供充足并行度, +// 多线程分摊内存带宽瓶颈, 在多核 CPU 上大幅加速。 +static void bf16_linear_avx2(uint16_t* Y, const uint16_t* X, + const uint16_t* W, const uint16_t* B, + size_t M, size_t K, size_t N) +{ + std::vector ybuf(M * N); + + if (B) { + for (size_t i = 0; i < M; ++i) { + float* yrow = ybuf.data() + i * N; + for (size_t j = 0; j < N; ++j) + yrow[j] = utils::cast(llaisys::bf16_t{B[j]}); + } + } else { + std::memset(ybuf.data(), 0, M * N * sizeof(float)); + } + + static constexpr size_t BLOCK_K = 512; + static constexpr size_t BLOCK_N = 4; + + const bool par_over_m = (M >= OMP_M_THRESHOLD); + // N 方向的块数 (每块 BLOCK_N 个输出神经元) + const size_t n_blocks = (N + BLOCK_N - 1) / BLOCK_N; + + for (size_t k0 = 0; k0 < K; k0 += BLOCK_K) { + const size_t kc = std::min(BLOCK_K, K - k0); + + // 内层函数: 计算一个 (j0, jn) 块的所有 M 行 + auto compute_block = [&](size_t j0, size_t jn) { + if (jn == 4) { + const uint16_t* w0 = W + (j0 + 0) * K + k0; + const uint16_t* w1 = W + (j0 + 1) * K + k0; + const uint16_t* w2 = W + (j0 + 2) * K + k0; + const uint16_t* w3 = W + (j0 + 3) * K + k0; + + for (size_t i = 0; i < M; ++i) { + const uint16_t* xi = X + i * K + k0; + + __m256 a00 = _mm256_setzero_ps(), a01 = _mm256_setzero_ps(); + __m256 a10 = _mm256_setzero_ps(), a11 = _mm256_setzero_ps(); + __m256 a20 = _mm256_setzero_ps(), a21 = _mm256_setzero_ps(); + __m256 a30 = _mm256_setzero_ps(), a31 = _mm256_setzero_ps(); + + size_t k = 0; + for (; k + 15 < kc; k += 16) { + __m256 x0 = utils::bf16x8_to_f32x8(xi + k); + __m256 x1 = utils::bf16x8_to_f32x8(xi + k + 8); + a00 = _mm256_fmadd_ps(x0, utils::bf16x8_to_f32x8(w0 + k), a00); + a01 = _mm256_fmadd_ps(x1, utils::bf16x8_to_f32x8(w0 + k + 8), a01); + a10 = _mm256_fmadd_ps(x0, utils::bf16x8_to_f32x8(w1 + k), a10); + a11 = _mm256_fmadd_ps(x1, utils::bf16x8_to_f32x8(w1 + k + 8), a11); + a20 = _mm256_fmadd_ps(x0, utils::bf16x8_to_f32x8(w2 + k), a20); + a21 = _mm256_fmadd_ps(x1, utils::bf16x8_to_f32x8(w2 + k + 8), a21); + a30 = _mm256_fmadd_ps(x0, utils::bf16x8_to_f32x8(w3 + k), a30); + a31 = _mm256_fmadd_ps(x1, utils::bf16x8_to_f32x8(w3 + k + 8), a31); + } + for (; k + 7 < kc; k += 8) { + __m256 x0 = utils::bf16x8_to_f32x8(xi + k); + a00 = _mm256_fmadd_ps(x0, utils::bf16x8_to_f32x8(w0 + k), a00); + a10 = _mm256_fmadd_ps(x0, utils::bf16x8_to_f32x8(w1 + k), a10); + a20 = _mm256_fmadd_ps(x0, utils::bf16x8_to_f32x8(w2 + k), a20); + a30 = _mm256_fmadd_ps(x0, utils::bf16x8_to_f32x8(w3 + k), a30); + } + + float d0 = utils::hsum256(_mm256_add_ps(a00, a01)); + float d1 = utils::hsum256(_mm256_add_ps(a10, a11)); + float d2 = utils::hsum256(_mm256_add_ps(a20, a21)); + float d3 = utils::hsum256(_mm256_add_ps(a30, a31)); + + for (; k < kc; ++k) { + float xk = utils::cast(llaisys::bf16_t{xi[k]}); + d0 += xk * utils::cast(llaisys::bf16_t{w0[k]}); + d1 += xk * utils::cast(llaisys::bf16_t{w1[k]}); + d2 += xk * utils::cast(llaisys::bf16_t{w2[k]}); + d3 += xk * utils::cast(llaisys::bf16_t{w3[k]}); + } + + ybuf[i*N + j0] += d0; + ybuf[i*N + j0 + 1] += d1; + ybuf[i*N + j0 + 2] += d2; + ybuf[i*N + j0 + 3] += d3; + } + } else { + for (size_t jj = 0; jj < jn; ++jj) { + const uint16_t* wj = W + (j0 + jj) * K + k0; + for (size_t i = 0; i < M; ++i) { + const uint16_t* xi = X + i * K + k0; + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + size_t k = 0; + for (; k + 15 < kc; k += 16) { + acc0 = _mm256_fmadd_ps(utils::bf16x8_to_f32x8(xi + k), + utils::bf16x8_to_f32x8(wj + k), acc0); + acc1 = _mm256_fmadd_ps(utils::bf16x8_to_f32x8(xi + k + 8), + utils::bf16x8_to_f32x8(wj + k + 8), acc1); + } + for (; k + 7 < kc; k += 8) + acc0 = _mm256_fmadd_ps(utils::bf16x8_to_f32x8(xi + k), + utils::bf16x8_to_f32x8(wj + k), acc0); + float sum = utils::hsum256(_mm256_add_ps(acc0, acc1)); + for (; k < kc; ++k) + sum += utils::cast(llaisys::bf16_t{xi[k]}) + * utils::cast(llaisys::bf16_t{wj[k]}); + ybuf[i*N + j0 + jj] += sum; + } + } + } + }; + + if (par_over_m) { + // 大 M: 原始策略, 按 M 行并行 (每块 j0 串行, 块内按行并行) + for (size_t j0 = 0; j0 < N; j0 += BLOCK_N) { + const size_t jn = std::min(BLOCK_N, N - j0); + // 这里需要按行并行, 用原始的 parallel_for(M, ...) 方式 + // 但为简化, 直接调用 compute_block (它内部对 M 循环串行) + // 然后外层 j0 循环串行 —— 对大 M, BLAS 路径已处理, 这里不会走到 + compute_block(j0, jn); + } + } else { + // 小 M (含 M=1 decode): 按 N 块并行 + // 每个线程处理一个 (j0, jn) 块的所有 M 行 + // 各块之间写不同的 ybuf 位置, 无竞争 +#ifdef _OPENMP + #pragma omp parallel for schedule(static) +#endif + for (size_t blk = 0; blk < n_blocks; ++blk) { + const size_t j0 = blk * BLOCK_N; + const size_t jn = std::min(BLOCK_N, N - j0); + compute_block(j0, jn); + } + } + } + + // f32 → bf16 写回 + for (size_t idx = 0; idx < M * N; ++idx) { + llaisys::bf16_t v = utils::cast(ybuf[idx]); + Y[idx] = v._v; + } +} + +// ============================================================ +// AVX-512 bf16 linear 实现 (仅在编译目标支持 AVX-512 时可用) +// 与 AVX2 版本结构相同, 但使用 512 位向量, 每次处理 16 个 bf16 元素, +// 吞吐量是 AVX2 的 2 倍。M=1 解码阶段受内存带宽限制, 更宽的向量 +// 能更充分利用带宽 (单次 load 指令读取 32 字节而非 16 字节)。 +// ============================================================ +#ifdef __AVX512F__ +static void bf16_linear_avx512(uint16_t* Y, const uint16_t* X, + const uint16_t* W, const uint16_t* B, + size_t M, size_t K, size_t N) +{ + std::vector ybuf(M * N); + + if (B) { + parallel_for(M, OMP_M_THRESHOLD, [&](size_t i) { + float* yrow = ybuf.data() + i * N; + for (size_t j = 0; j < N; ++j) + yrow[j] = utils::cast(llaisys::bf16_t{B[j]}); + }); + } else { + std::memset(ybuf.data(), 0, M * N * sizeof(float)); + } + + static constexpr size_t BLOCK_K = 512; + static constexpr size_t BLOCK_N = 4; + + for (size_t k0 = 0; k0 < K; k0 += BLOCK_K) { + const size_t kc = std::min(BLOCK_K, K - k0); + + for (size_t j0 = 0; j0 < N; j0 += BLOCK_N) { + const size_t jn = std::min(BLOCK_N, N - j0); + + if (jn == 4) { + const uint16_t* w0 = W + (j0 + 0) * K + k0; + const uint16_t* w1 = W + (j0 + 1) * K + k0; + const uint16_t* w2 = W + (j0 + 2) * K + k0; + const uint16_t* w3 = W + (j0 + 3) * K + k0; + + parallel_for(M, OMP_M_THRESHOLD, [&](size_t i) { + const uint16_t* xi = X + i * K + k0; + + // 双累加器, 每个 512 位 (16 floats) + __m512 a00 = _mm512_setzero_ps(), a01 = _mm512_setzero_ps(); + __m512 a10 = _mm512_setzero_ps(), a11 = _mm512_setzero_ps(); + __m512 a20 = _mm512_setzero_ps(), a21 = _mm512_setzero_ps(); + __m512 a30 = _mm512_setzero_ps(), a31 = _mm512_setzero_ps(); + + size_t k = 0; + // 每次处理 32 个 bf16 元素 (2 x 16) + for (; k + 31 < kc; k += 32) { + __m512 x0 = utils::bf16x16_to_f32x16(xi + k); + __m512 x1 = utils::bf16x16_to_f32x16(xi + k + 16); + a00 = _mm512_fmadd_ps(x0, utils::bf16x16_to_f32x16(w0 + k), a00); + a01 = _mm512_fmadd_ps(x1, utils::bf16x16_to_f32x16(w0 + k + 16), a01); + a10 = _mm512_fmadd_ps(x0, utils::bf16x16_to_f32x16(w1 + k), a10); + a11 = _mm512_fmadd_ps(x1, utils::bf16x16_to_f32x16(w1 + k + 16), a11); + a20 = _mm512_fmadd_ps(x0, utils::bf16x16_to_f32x16(w2 + k), a20); + a21 = _mm512_fmadd_ps(x1, utils::bf16x16_to_f32x16(w2 + k + 16), a21); + a30 = _mm512_fmadd_ps(x0, utils::bf16x16_to_f32x16(w3 + k), a30); + a31 = _mm512_fmadd_ps(x1, utils::bf16x16_to_f32x16(w3 + k + 16), a31); + } + // 处理剩余的 16 个元素 + for (; k + 15 < kc; k += 16) { + __m512 x0 = utils::bf16x16_to_f32x16(xi + k); + a00 = _mm512_fmadd_ps(x0, utils::bf16x16_to_f32x16(w0 + k), a00); + a10 = _mm512_fmadd_ps(x0, utils::bf16x16_to_f32x16(w1 + k), a10); + a20 = _mm512_fmadd_ps(x0, utils::bf16x16_to_f32x16(w2 + k), a20); + a30 = _mm512_fmadd_ps(x0, utils::bf16x16_to_f32x16(w3 + k), a30); + } + + float d0 = utils::hsum512(_mm512_add_ps(a00, a01)); + float d1 = utils::hsum512(_mm512_add_ps(a10, a11)); + float d2 = utils::hsum512(_mm512_add_ps(a20, a21)); + float d3 = utils::hsum512(_mm512_add_ps(a30, a31)); + + // AVX2 尾部处理 (8 元素) + for (; k + 7 < kc; k += 8) { + __m256 x0 = utils::bf16x8_to_f32x8(xi + k); + d0 += utils::hsum256(_mm256_mul_ps(x0, utils::bf16x8_to_f32x8(w0 + k))); + d1 += utils::hsum256(_mm256_mul_ps(x0, utils::bf16x8_to_f32x8(w1 + k))); + d2 += utils::hsum256(_mm256_mul_ps(x0, utils::bf16x8_to_f32x8(w2 + k))); + d3 += utils::hsum256(_mm256_mul_ps(x0, utils::bf16x8_to_f32x8(w3 + k))); + } + + // 标量尾部 + for (; k < kc; ++k) { + float xk = utils::cast(llaisys::bf16_t{xi[k]}); + d0 += xk * utils::cast(llaisys::bf16_t{w0[k]}); + d1 += xk * utils::cast(llaisys::bf16_t{w1[k]}); + d2 += xk * utils::cast(llaisys::bf16_t{w2[k]}); + d3 += xk * utils::cast(llaisys::bf16_t{w3[k]}); + } + + ybuf[i*N + j0] += d0; + ybuf[i*N + j0 + 1] += d1; + ybuf[i*N + j0 + 2] += d2; + ybuf[i*N + j0 + 3] += d3; + }); + } else { + for (size_t jj = 0; jj < jn; ++jj) { + const uint16_t* wj = W + (j0 + jj) * K + k0; + parallel_for(M, OMP_M_THRESHOLD, [&](size_t i) { + const uint16_t* xi = X + i * K + k0; + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + size_t k = 0; + for (; k + 31 < kc; k += 32) { + acc0 = _mm512_fmadd_ps(utils::bf16x16_to_f32x16(xi + k), + utils::bf16x16_to_f32x16(wj + k), acc0); + acc1 = _mm512_fmadd_ps(utils::bf16x16_to_f32x16(xi + k + 16), + utils::bf16x16_to_f32x16(wj + k + 16), acc1); + } + for (; k + 15 < kc; k += 16) + acc0 = _mm512_fmadd_ps(utils::bf16x16_to_f32x16(xi + k), + utils::bf16x16_to_f32x16(wj + k), acc0); + float sum = utils::hsum512(_mm512_add_ps(acc0, acc1)); + // AVX2 尾部 + for (; k + 7 < kc; k += 8) { + sum += utils::hsum256(_mm256_mul_ps( + utils::bf16x8_to_f32x8(xi + k), + utils::bf16x8_to_f32x8(wj + k))); + } + for (; k < kc; ++k) + sum += utils::cast(llaisys::bf16_t{xi[k]}) + * utils::cast(llaisys::bf16_t{wj[k]}); + ybuf[i*N + j0 + jj] += sum; + }); + } + } + } + } + + // f32 → bf16 写回 + parallel_for(M * N, OMP_M_THRESHOLD, [&](size_t idx) { + llaisys::bf16_t v = utils::cast(ybuf[idx]); + Y[idx] = v._v; + }); +} +#endif // __AVX512F__ + +template<> +void linear_cpu_kernel(tensor_t out, tensor_t in, + tensor_t weight, tensor_t bias) +{ + const uint16_t* X = reinterpret_cast(in->data()); + const uint16_t* W = reinterpret_cast(weight->data()); + const uint16_t* B = nullptr; + if (bias && bias->numel() > 0) + B = reinterpret_cast(bias->data()); + uint16_t* Y = reinterpret_cast(out->data()); + + size_t M, K, N; + get_dims(in, weight, M, K, N); + + // BLAS 路径: M >= 阈值时使用 sgemm (编译期或运行时) + if (has_blas() && M >= BLAS_M_THRESHOLD) { + // 大 M: bf16 → f32 → sgemm → f32 → bf16 + // 权重 W 在首次调用时转为 f32 并缓存, 后续直接复用。 + + // 缓存 f32 权重 (仅首次转换, 后续复用) + const float* Wf = get_cached_f32_weights(W, N * K, /*dtype_tag=*/0, + [&](float* dst) { bf16_to_f32_bulk(W, dst, N * K); }); + + // 输入每次都转换 + std::vector Xf(M * K); + bf16_to_f32_bulk(X, Xf.data(), M * K); + + std::vector Cf(M * N); + + if (B) { + const float* bias_f = get_cached_f32_weights(B, N, /*dtype_tag=*/0, + [&](float* dst) { bf16_to_f32_bulk(B, dst, N); }); + for (size_t i = 0; i < M; ++i) + std::memcpy(Cf.data() + i * N, bias_f, N * sizeof(float)); + call_sgemm((int)M, (int)N, (int)K, + 1.0f, Xf.data(), (int)K, Wf, (int)K, + 1.0f, Cf.data(), (int)N); + } else { + call_sgemm((int)M, (int)N, (int)K, + 1.0f, Xf.data(), (int)K, Wf, (int)K, + 0.0f, Cf.data(), (int)N); + } + + // f32 → bf16 写回 + for (size_t i = 0; i < M * N; ++i) { + llaisys::bf16_t v = utils::cast(Cf[i]); + Y[i] = v._v; + } + return; + } + + // 小 M (含 M=1 decode): 手写 SIMD, 直接读 bf16, 节省内存带宽 + // M=1 时不启用 AVX-512 (频率降档惩罚 > 吞吐收益), 统一使用 AVX2 + bf16_linear_avx2(Y, X, W, B, M, K, N); +} + +// ============================================================ +// float16 特化 +// 有 BLAS 时将 fp16 转为 f32 再调用 sgemm; +// 否则使用手写 AVX2 + OpenMP 并行化版本。 +// ============================================================ +template<> +void linear_cpu_kernel(tensor_t out, tensor_t in, + tensor_t weight, tensor_t bias) +{ + const uint16_t* X = reinterpret_cast(in->data()); + const uint16_t* W = reinterpret_cast(weight->data()); + const uint16_t* B = nullptr; + if (bias && bias->numel() > 0) + B = reinterpret_cast(bias->data()); + uint16_t* Y_out = reinterpret_cast(out->data()); + + size_t M, K, N; + get_dims(in, weight, M, K, N); + + // BLAS 路径 (编译期或运行时) + if (has_blas()) { + // 缓存 f32 权重 (仅首次转换, 后续复用) + const float* Wf = get_cached_f32_weights(W, N * K, /*dtype_tag=*/1, + [&](float* dst) { fp16_to_f32_bulk(W, dst, N * K); }); + + // 输入每次都转换 + std::vector Xf(M * K); + fp16_to_f32_bulk(X, Xf.data(), M * K); + + std::vector Cf(M * N); + + if (B) { + const float* bias_f = get_cached_f32_weights(B, N, /*dtype_tag=*/1, + [&](float* dst) { fp16_to_f32_bulk(B, dst, N); }); + for (size_t i = 0; i < M; ++i) + std::memcpy(Cf.data() + i * N, bias_f, N * sizeof(float)); + call_sgemm((int)M, (int)N, (int)K, + 1.0f, Xf.data(), (int)K, Wf, (int)K, + 1.0f, Cf.data(), (int)N); + } else { + call_sgemm((int)M, (int)N, (int)K, + 1.0f, Xf.data(), (int)K, Wf, (int)K, + 0.0f, Cf.data(), (int)N); + } + + // f32 → fp16 写回 + for (size_t i = 0; i < M * N; ++i) { + llaisys::fp16_t v = utils::cast(Cf[i]); + Y_out[i] = v._v; + } + return; + } + + // ---------- 手写 AVX2+FMA fallback ---------- + std::vector ybuf(M * N); + + if (B) { + parallel_for(M, OMP_M_THRESHOLD, [&](size_t i) { + float* yrow = ybuf.data() + i * N; + for (size_t j = 0; j < N; ++j) + yrow[j] = utils::cast(llaisys::fp16_t{B[j]}); + }); + } else { + std::memset(ybuf.data(), 0, M * N * sizeof(float)); + } + + static constexpr size_t BLOCK_K = 512; + static constexpr size_t BLOCK_N = 4; + + for (size_t k0 = 0; k0 < K; k0 += BLOCK_K) { + const size_t kc = std::min(BLOCK_K, K - k0); + + for (size_t j0 = 0; j0 < N; j0 += BLOCK_N) { + const size_t jn = std::min(BLOCK_N, N - j0); + + if (jn == 4) { + const uint16_t* w0 = W + (j0 + 0) * K + k0; + const uint16_t* w1 = W + (j0 + 1) * K + k0; + const uint16_t* w2 = W + (j0 + 2) * K + k0; + const uint16_t* w3 = W + (j0 + 3) * K + k0; + + parallel_for(M, OMP_M_THRESHOLD, [&](size_t i) { + const uint16_t* xi = X + i * K + k0; + + __m256 a00 = _mm256_setzero_ps(), a01 = _mm256_setzero_ps(); + __m256 a10 = _mm256_setzero_ps(), a11 = _mm256_setzero_ps(); + __m256 a20 = _mm256_setzero_ps(), a21 = _mm256_setzero_ps(); + __m256 a30 = _mm256_setzero_ps(), a31 = _mm256_setzero_ps(); + + size_t k = 0; + for (; k + 15 < kc; k += 16) { + __m256 x0 = utils::fp16x8_to_f32x8(xi + k); + __m256 x1 = utils::fp16x8_to_f32x8(xi + k + 8); + a00 = _mm256_fmadd_ps(x0, utils::fp16x8_to_f32x8(w0 + k), a00); + a01 = _mm256_fmadd_ps(x1, utils::fp16x8_to_f32x8(w0 + k + 8), a01); + a10 = _mm256_fmadd_ps(x0, utils::fp16x8_to_f32x8(w1 + k), a10); + a11 = _mm256_fmadd_ps(x1, utils::fp16x8_to_f32x8(w1 + k + 8), a11); + a20 = _mm256_fmadd_ps(x0, utils::fp16x8_to_f32x8(w2 + k), a20); + a21 = _mm256_fmadd_ps(x1, utils::fp16x8_to_f32x8(w2 + k + 8), a21); + a30 = _mm256_fmadd_ps(x0, utils::fp16x8_to_f32x8(w3 + k), a30); + a31 = _mm256_fmadd_ps(x1, utils::fp16x8_to_f32x8(w3 + k + 8), a31); + } + for (; k + 7 < kc; k += 8) { + __m256 x0 = utils::fp16x8_to_f32x8(xi + k); + a00 = _mm256_fmadd_ps(x0, utils::fp16x8_to_f32x8(w0 + k), a00); + a10 = _mm256_fmadd_ps(x0, utils::fp16x8_to_f32x8(w1 + k), a10); + a20 = _mm256_fmadd_ps(x0, utils::fp16x8_to_f32x8(w2 + k), a20); + a30 = _mm256_fmadd_ps(x0, utils::fp16x8_to_f32x8(w3 + k), a30); + } + + float d0 = utils::hsum256(_mm256_add_ps(a00, a01)); + float d1 = utils::hsum256(_mm256_add_ps(a10, a11)); + float d2 = utils::hsum256(_mm256_add_ps(a20, a21)); + float d3 = utils::hsum256(_mm256_add_ps(a30, a31)); + + for (; k < kc; ++k) { + float xk = utils::cast(llaisys::fp16_t{xi[k]}); + d0 += xk * utils::cast(llaisys::fp16_t{w0[k]}); + d1 += xk * utils::cast(llaisys::fp16_t{w1[k]}); + d2 += xk * utils::cast(llaisys::fp16_t{w2[k]}); + d3 += xk * utils::cast(llaisys::fp16_t{w3[k]}); + } + + ybuf[i*N + j0] += d0; + ybuf[i*N + j0 + 1] += d1; + ybuf[i*N + j0 + 2] += d2; + ybuf[i*N + j0 + 3] += d3; + }); + } else { + for (size_t jj = 0; jj < jn; ++jj) { + const uint16_t* wj = W + (j0 + jj) * K + k0; + parallel_for(M, OMP_M_THRESHOLD, [&](size_t i) { + const uint16_t* xi = X + i * K + k0; + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + size_t k = 0; + for (; k + 15 < kc; k += 16) { + acc0 = _mm256_fmadd_ps(utils::fp16x8_to_f32x8(xi + k), + utils::fp16x8_to_f32x8(wj + k), acc0); + acc1 = _mm256_fmadd_ps(utils::fp16x8_to_f32x8(xi + k + 8), + utils::fp16x8_to_f32x8(wj + k + 8), acc1); + } + for (; k + 7 < kc; k += 8) + acc0 = _mm256_fmadd_ps(utils::fp16x8_to_f32x8(xi + k), + utils::fp16x8_to_f32x8(wj + k), acc0); + float sum = utils::hsum256(_mm256_add_ps(acc0, acc1)); + for (; k < kc; ++k) + sum += utils::cast(llaisys::fp16_t{xi[k]}) + * utils::cast(llaisys::fp16_t{wj[k]}); + ybuf[i*N + j0 + jj] += sum; + }); + } + } + } + } + + // f32 → fp16 写回 + parallel_for(M * N, OMP_M_THRESHOLD, [&](size_t idx) { + llaisys::fp16_t v = utils::cast(ybuf[idx]); + Y_out[idx] = v._v; + }); +} + +// ============================================================ +// 入口函数 +// ============================================================ void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) { - TO_BE_IMPLEMENTED(); + in = in->isContiguous() ? in : in->contiguous(); + weight = weight->isContiguous() ? weight : weight->contiguous(); + switch (in->dtype()) { + case LLAISYS_DTYPE_F16: + linear_cpu_kernel(out, in, weight, bias); + break; + case LLAISYS_DTYPE_BF16: + linear_cpu_kernel(out, in, weight, bias); + break; + case LLAISYS_DTYPE_F32: + linear_cpu_kernel(out, in, weight, bias); + break; + default: + throw std::runtime_error("Not support this dtype!"); + } } } // namespace llaisys::ops diff --git a/src/ops/rearrange/op.cpp b/src/ops/rearrange/op.cpp index 017a6ae59..9f4d6c9e4 100644 --- a/src/ops/rearrange/op.cpp +++ b/src/ops/rearrange/op.cpp @@ -1,7 +1,39 @@ #include "op.hpp" namespace llaisys::ops { +template +void rearrange_kernel(tensor_t out, tensor_t in){ + const T*in_ptr=reinterpret_cast(in->data()); + T*out_ptr=reinterpret_cast(out->data()); + size_t total_num=in->numel(); + size_t dim=in->ndim(); + for(size_t i=0;i 0;) { + size_t cur_shape=in->shape()[j]; + size_t cur_index=index_acc%cur_shape; + index_acc/=cur_shape; + in_offset+=cur_index*in->strides()[j]; + out_offset+=cur_index*out->strides()[j]; + } + out_ptr[out_offset]=in_ptr[in_offset]; + } +} void rearrange(tensor_t out, tensor_t in) { - TO_BE_IMPLEMENTED(); + switch (in->dtype()) { + case LLAISYS_DTYPE_F16: + rearrange_kernel(out,in); + break; + case LLAISYS_DTYPE_BF16: + rearrange_kernel(out,in); + break; + case LLAISYS_DTYPE_F32: + rearrange_kernel(out,in); + break; + default: + throw std::runtime_error("Not support this dtype!"); + } } } // namespace llaisys::ops diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp index 529553d9d..a00924322 100644 --- a/src/ops/rms_norm/op.cpp +++ b/src/ops/rms_norm/op.cpp @@ -1,7 +1,43 @@ #include "op.hpp" - +#include namespace llaisys::ops { +template +void rms_norm_cpu_kernel(tensor_t out, tensor_t in, tensor_t weight, float eps){ + T*out_ptr=reinterpret_cast(out->data()); + const T*in_ptr=reinterpret_cast(in->data()); + const T*weight_ptr=reinterpret_cast(weight->data()); + size_t d=in->shape().back(); + size_t n = in->numel() / d; + for(size_t i=0;i(in_ptr[i*d+j]); + sum+=cur_num*cur_num; + } + sum/=(float)d; + float std_sum=sqrtf(sum+eps); + for(size_t j=0;j(weight_ptr[j]); + float xi=utils::cast(in_ptr[i*d+j]); + float out_val=wi*xi/std_sum; + out_ptr[i*d+j]=utils::cast(out_val); + } + } +} void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) { - TO_BE_IMPLEMENTED(); + tensor_t contiguous_in = in->isContiguous() ? in : in->contiguous(); + switch (in->dtype()) { + case LLAISYS_DTYPE_F16: + rms_norm_cpu_kernel(out,contiguous_in,weight,eps); + break; + case LLAISYS_DTYPE_BF16: + rms_norm_cpu_kernel(out,contiguous_in,weight,eps); + break; + case LLAISYS_DTYPE_F32: + rms_norm_cpu_kernel(out,contiguous_in,weight,eps); + break; + default: + throw std::runtime_error("Not support this dtype!"); + } } } // namespace llaisys::ops diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64e..b505488a3 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,154 @@ #include "op.hpp" +#include namespace llaisys::ops { +template +void rope_cpu_kernel(tensor_t out, tensor_t in, tensor_t pos_ids, float theta){ + T* out_ptr = reinterpret_cast(out->data()); + const T* in_ptr = reinterpret_cast(in->data()); + const int64_t* pos_ptr = reinterpret_cast(pos_ids->data()); + + const auto& in_shape = in->shape(); + const auto& pos_shape = pos_ids->shape(); + + // ----------------------------- + // 模式 1:单元测试 / 通用 3D 形式 + // in : [seq_len, n_heads, head_dim] + // pos : [seq_len] + // 与 test/ops/rope.py 中 torch_rope 完全对齐 + // 这里严格按 Tensor 的 strides 访问,避免对内存布局做任何假设 + // ----------------------------- + if (in_shape.size() == 3 && pos_shape.size() == 1 && pos_ids->numel() == in_shape[0]) { + size_t seq_len = in_shape[0]; + size_t n_heads = in_shape[1]; + size_t head_dim = in_shape[2]; + + const auto& in_strides = in->strides(); + const auto& out_strides = out->strides(); + ptrdiff_t s_in_0 = in_strides[0]; + ptrdiff_t s_in_1 = in_strides[1]; + ptrdiff_t s_in_2 = in_strides[2]; + ptrdiff_t s_out_0 = out_strides[0]; + ptrdiff_t s_out_1 = out_strides[1]; + ptrdiff_t s_out_2 = out_strides[2]; + + for (size_t i = 0; i < seq_len; ++i) { + float p_i = utils::cast(pos_ptr[i]); + for (size_t h = 0; h < n_heads; ++h) { + for (size_t k = 0; k < head_dim / 2; ++k) { + // 计算输入/输出索引,完全依赖 strides + size_t idx_a_in = i * s_in_0 + h * s_in_1 + k * s_in_2; + size_t idx_b_in = i * s_in_0 + h * s_in_1 + (k + head_dim / 2) * s_in_2; + size_t idx_a_out = i * s_out_0 + h * s_out_1 + k * s_out_2; + size_t idx_b_out = i * s_out_0 + h * s_out_1 + (k + head_dim / 2) * s_out_2; + + float theta_in = p_i / std::pow(theta, 2.0f * k / head_dim); + + float a_in = utils::cast(in_ptr[idx_a_in]); + float b_in = utils::cast(in_ptr[idx_b_in]); + + float cos_t = std::cos(theta_in); + float sin_t = std::sin(theta_in); + + T a_out = utils::cast(a_in * cos_t - b_in * sin_t); + T b_out = utils::cast(b_in * cos_t + a_in * sin_t); + + out_ptr[idx_a_out] = a_out; + out_ptr[idx_b_out] = b_out; + } + } + } + return; + } + + // ----------------------------- + // 模式 2:Qwen2 / DeepSeek 推理路径 + // in : [batch, seq_len, hidden] + // pos : [1, seq_len] 或 [seq_len] + // ----------------------------- + size_t N = in_shape[0]; // batch + size_t M = in_shape[1]; // seq_len + size_t D = in_shape[2]; // hidden + + size_t n_pos = pos_ids->numel(); + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < M; j++) { + size_t flat_idx = i * M + j; + float p_i; + if (n_pos == M) { + p_i = utils::cast(pos_ptr[j]); + } else if (n_pos == N * M) { + p_i = utils::cast(pos_ptr[flat_idx]); + } else { + p_i = utils::cast(pos_ptr[flat_idx % n_pos]); + } + + size_t base_offset = i * (M * D) + j * D; + + if (D == 1536 || D == 256) { + size_t head_dim = 128; + size_t n_heads = D / 128; + + for (size_t h = 0; h < n_heads; h++) { + for (size_t k = 0; k < head_dim / 2; k++) { + size_t offset = base_offset + h * head_dim; + size_t idx_a = offset + k; + size_t idx_b = offset + k + head_dim / 2; + + float theta_in = p_i / std::pow(theta, 2.0f * k / head_dim); + + float a_in = utils::cast(in_ptr[idx_a]); + float b_in = utils::cast(in_ptr[idx_b]); + + float cos_t = std::cos(theta_in); + float sin_t = std::sin(theta_in); + + T a_out = utils::cast(a_in * cos_t - b_in * sin_t); + T b_out = utils::cast(b_in * cos_t + a_in * sin_t); + + out_ptr[idx_a] = a_out; + out_ptr[idx_b] = b_out; + } + } + } else { + for (size_t k = 0; k < D / 2; k++) { + size_t idx_a = base_offset + k; + size_t idx_b = base_offset + k + D / 2; + + float theta_in = p_i / std::pow(theta, 2.0f * k / D); + + float a_in = utils::cast(in_ptr[idx_a]); + float b_in = utils::cast(in_ptr[idx_b]); + + float cos_t = std::cos(theta_in); + float sin_t = std::sin(theta_in); + + T a_out = utils::cast(a_in * cos_t - b_in * sin_t); + T b_out = utils::cast(b_in * cos_t + a_in * sin_t); + + out_ptr[idx_a] = a_out; + out_ptr[idx_b] = b_out; + } + } + } + } +} + void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + tensor_t cin = in->isContiguous() ? in : in->contiguous(); + switch (in->dtype()) { + case LLAISYS_DTYPE_F16: + rope_cpu_kernel(out,cin,pos_ids,theta); + break; + case LLAISYS_DTYPE_BF16: + rope_cpu_kernel(out,cin,pos_ids,theta); + break; + case LLAISYS_DTYPE_F32: + rope_cpu_kernel(out,cin,pos_ids,theta); + break; + default: + throw std::runtime_error("Not support this dtype!"); + } } -} // namespace llaisys::ops +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp index 43d620142..246ba066f 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,188 @@ #include "op.hpp" +#include +#include +#include + +const int INF = 0x3f3f3f3f; namespace llaisys::ops { + +template +void self_attention_kernel(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale){ + T* attn_val_ptr_base = reinterpret_cast(attn_val->data()); + const T* q_ptr_base = reinterpret_cast(q->data()); + const T* k_ptr_base = reinterpret_cast(k->data()); + const T* v_ptr_base = reinterpret_cast(v->data()); + + size_t batch, seqlen, nhead, d, total_len, nkvhead, dv; + + // 【关键】智能判断模式 + // DeepSeek 的 Hidden Dim 是 1536 或 256 + size_t last_dim = q->shape().back(); + bool is_deepseek = (last_dim == 1536 || last_dim == 256); + + if (is_deepseek) { + // -------------------------------------------------------- + // 模式 A: DeepSeek 推理模式 [Batch, Seq, Hidden] + // 说明: + // - q 始终是 [B, T, H],H=nh*head_dim (例如 1536 = 12 * 128) + // - k/v 在两种场景下: + // 1) 即时计算时为 [B, T, kv_dim],其中 kv_dim = nkvh * head_dim + // 2) 从 KV Cache 读出后为 [B, T_total, nkvh, head_dim] + // -------------------------------------------------------- + batch = q->shape()[0]; + seqlen = q->shape()[1]; + size_t hidden_q = q->shape()[2]; + + // 拆 head:Q 的 head 维度始终按 128 处理 + d = 128; + nhead = hidden_q / d; // 例如 1536 / 128 = 12 + + total_len = k->shape()[1]; + + if (k->shape().size() == 3) { + // [B, T_total, kv_dim] 视为 [B, T_total, nkvh, head_dim] 的拍扁形式 + size_t hidden_kv = k->shape()[2]; + dv = d; + nkvhead = hidden_kv / dv; // 例如 256 / 128 = 2 + } else if (k->shape().size() == 4) { + // [B, T_total, nkvh, head_dim] —— 来自 KV Cache 的 4D 形式 + nkvhead = k->shape()[2]; + dv = k->shape()[3]; // 一般为 128 + } else { + throw std::runtime_error("Unsupported K shape for DeepSeek mode"); + } + } + else if (k->shape().size() == 4) { + // -------------------------------------------------------- + // 模式 B: 4D 标准模式 [Batch, Seq, Head, Dim] + // -------------------------------------------------------- + batch = q->shape()[0]; + seqlen = q->shape()[1]; + size_t hidden_size = q->shape()[2]; + total_len = k->shape()[1]; + nkvhead = k->shape()[2]; + d = k->shape()[3]; + nhead = hidden_size / d; + dv = d; + } + else { + // -------------------------------------------------------- + // 模式 C: 单元测试/通用兼容模式 [Seq, Head, Dim] (你最初的逻辑) + // -------------------------------------------------------- + // 这里的 shape[0] 被视为 SeqLen,Batch 被视为 1 + batch = 1; + seqlen = q->shape()[0]; + nhead = q->shape()[1]; + d = q->shape()[2]; + + total_len = k->shape()[0]; + nkvhead = k->shape()[1]; + dv = v->shape()[2]; + } + + // 计算 Stride (内存跨度) + size_t stride_q = seqlen * nhead * d; + size_t stride_k = total_len * nkvhead * d; + size_t stride_v = total_len * nkvhead * dv; + size_t stride_out = seqlen * nhead * dv; + + // 针对 DeepSeek 的广播检查 + if (is_deepseek) { + size_t batch_k = k->shape()[0]; + size_t batch_v = v->shape()[0]; + if (batch_k < batch) stride_k = 0; + if (batch_v < batch) stride_v = 0; + } + + // 执行 Batch 循环 + // 注意:在模式 C (单元测试) 下,batch=1,只会跑一次,完美复现你最初的逻辑 + for (size_t b = 0; b < batch; b++) { + + T* attn_val_ptr = attn_val_ptr_base + b * stride_out; + const T* q_ptr = q_ptr_base + b * stride_q; + const T* k_ptr = k_ptr_base + b * stride_k; + const T* v_ptr = v_ptr_base + b * stride_v; + + std::vector A(total_len); + size_t group_size = (nkvhead == 0) ? 1 : (nhead / nkvhead); + if(group_size == 0) group_size = 1; + + auto get_k_index = [&](size_t n_index) -> size_t { + if(nkvhead == 1) return 0; + if(nkvhead == nhead) return n_index; + return n_index / group_size; + }; + + for(size_t n_index = 0; n_index < nhead; n_index++){ + size_t k_index = get_k_index(n_index); + for(size_t i = 0; i < seqlen; i++){ + float MAX_num = -1e30f; + // Q * K + for(size_t j = 0; j < total_len; j++){ + float sum = 0.0f; + for(size_t k = 0; k < d; k++){ + float q_val = utils::cast(q_ptr[i*(nhead*d)+n_index*d+k]); + float k_val = utils::cast(k_ptr[j*(nkvhead*d)+k_index*d+k]); + sum += q_val * k_val; + } + sum *= scale; + // Causal Mask logic + size_t global_i = total_len - seqlen + i; + if(j > global_i){ + A[j] = utils::cast(-INF); + } + else{ + A[j] = utils::cast(sum); + MAX_num = std::fmax(MAX_num, sum); + } + } + + // Softmax + float softmax_accu = 0; + for(size_t j = 0; j < total_len; j++){ + float a_val = utils::cast(A[j]); + softmax_accu += std::exp(a_val - MAX_num); + } + + float inv_accu = 1.0f / softmax_accu; + for(size_t j = 0; j < total_len; j++){ + float a_val = utils::cast(A[j]); + float a_sval = std::exp(a_val - MAX_num); + A[j] = utils::cast(a_sval * inv_accu); + } + + // Weighted Sum + for(size_t j = 0; j < dv; j++){ + float out_sum = 0.0f; + for(size_t k = 0; k < total_len; k++){ + float a_cv = utils::cast(A[k]); + float v_cv = utils::cast(v_ptr[k*nkvhead*dv+k_index*dv+j]); + out_sum += a_cv * v_cv; + } + attn_val_ptr[i*nhead*dv+n_index*dv+j] = utils::cast(out_sum); + } + } + } + } +} + void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); + tensor_t cq = q->isContiguous() ? q : q->contiguous(); + tensor_t ck = k->isContiguous() ? k : k->contiguous(); + tensor_t cv = v->isContiguous() ? v : v->contiguous(); + switch (q->dtype()) { + case LLAISYS_DTYPE_F16: + self_attention_kernel(attn_val,cq,ck,cv,scale); + break; + case LLAISYS_DTYPE_BF16: + self_attention_kernel(attn_val,cq,ck,cv,scale); + break; + case LLAISYS_DTYPE_F32: + self_attention_kernel(attn_val,cq,ck,cv,scale); + break; + default: + throw std::runtime_error("Not support this dtype!"); + } } -} // namespace llaisys::ops +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp index 47edbcc97..468ff2026 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,40 @@ #include "op.hpp" +#include namespace llaisys::ops { + +template +void swiglu_kernel(tensor_t out, tensor_t gate, tensor_t up){ + T* out_ptr = reinterpret_cast(out->data()); + const T* gate_ptr = reinterpret_cast(gate->data()); + const T* up_ptr = reinterpret_cast(up->data()); + size_t n = out->numel(); + + for(size_t i = 0; i < n; i++){ + float up_val = utils::cast(up_ptr[i]); + float gate_val = utils::cast(gate_ptr[i]); + + // Swish / SiLU: x / (1 + exp(-x)) + float t_val = gate_val / (1.0f + std::exp(-gate_val)); + + float out_val = up_val * t_val; + out_ptr[i] = utils::cast(out_val); + } +} + void swiglu(tensor_t out, tensor_t gate, tensor_t up) { - TO_BE_IMPLEMENTED(); + switch (out->dtype()) { + case LLAISYS_DTYPE_F16: + swiglu_kernel(out,gate,up); + break; + case LLAISYS_DTYPE_BF16: + swiglu_kernel(out,gate,up); + break; + case LLAISYS_DTYPE_F32: + swiglu_kernel(out,gate,up); + break; + default: + throw std::runtime_error("Not support this dtype!"); + } } -} // namespace llaisys::ops +} // namespace llaisys::ops \ No newline at end of file diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb65..59ded9b7e 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -1,7 +1,6 @@ #include "tensor.hpp" - +#include "../ops/rearrange/op.hpp" #include "../utils.hpp" - #include #include #include @@ -164,32 +163,138 @@ void Tensor::debug() const { } bool Tensor::isContiguous() const { - TO_BE_IMPLEMENTED(); + size_t Rank=this->ndim(); + const auto&Cur_strides=this->strides(); + const auto&Shapes=this->shape(); + ptrdiff_t accumulate_stride=1; + if(Rank==0) return true; + for(size_t i=Rank;i>0;--i){ + size_t index=i-1; + if(accumulate_stride!=Cur_strides[index]) return false; + accumulate_stride*=Shapes[index]; + } return true; } tensor_t Tensor::permute(const std::vector &order) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + if(order.size()!=this->ndim()){ + throw std::runtime_error("Order Error!"); + } + const auto&old_shape=this->shape(); + const auto&old_strides=this->strides(); + std::vector new_shape(old_shape.size()); + std::vector new_strides(old_strides.size()); + for(size_t i=0;i=this->ndim()){ + throw std::runtime_error("Index Error!"); + } + new_shape[i]=old_shape[order_index]; + new_strides[i]=old_strides[order_index]; + } + TensorMeta _meta{this->dtype(),std::move(new_shape),std::move(new_strides)}; + return std::shared_ptr(new Tensor(_meta, _storage,this->_offset)); } tensor_t Tensor::view(const std::vector &shape) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + auto target_numel=std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies()); + if(this->numel()!=target_numel) throw std::runtime_error("size error"); + if (this->isContiguous()) { + std::vector new_strides(shape.size()); + size_t stride = 1; + for (size_t i = shape.size(); i-- > 0;) { + new_strides[i] = static_cast(stride); + stride *= shape[i]; + } + TensorMeta meta{this->dtype(), shape, new_strides}; + return std::shared_ptr(new Tensor(meta, this->_storage, this->_offset)); + } + std::vector new_strides;new_strides.reserve(shape.size()); + const auto&old_strides=this->strides(); + const auto&old_shape=this->shape(); + size_t old_dim_index=0,split_divisor=1; + for(size_t new_dim_index=0;new_dim_index=old_shape.size()){ + throw std::runtime_error("Dim Error!"); + } + size_t available_dim_size=old_shape[old_dim_index]/split_divisor; + size_t original_stride=old_strides[old_dim_index]; + if(target_dim=old_shape.size()){ + throw std::runtime_error("Dim Error!"); + } + if(accumulated_size>1){ + if(old_strides[old_dim_index-1]!=static_cast(old_shape[old_dim_index]*old_strides[old_dim_index])){ + throw std::runtime_error("Transform Error!"); + } + } + accumulated_size*=old_shape[old_dim_index++]; + } + if(accumulated_size!=target_dim){ + throw std::runtime_error("Match Error!"); + } + new_strides.emplace_back(old_strides[old_dim_index-1]); + } + TensorMeta _meta{this->dtype(),shape,std::move(new_strides)}; + return std::shared_ptr(new Tensor(_meta, _storage,this->_offset)); } tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + if(dim>=this->ndim()) throw std::runtime_error("Dim Error!"); + if(start>=end) throw std::runtime_error("Index Error!"); + if(end>this->shape()[dim]) throw std::runtime_error("End Error!"); + auto new_shape=this->shape(); + new_shape[dim]=end-start; + auto new_offset=this->_offset+start*this->strides()[dim]*this->elementSize(); + TensorMeta _meta{this->dtype(),std::move(new_shape),this->strides()}; + return std::shared_ptr(new Tensor(_meta, _storage,new_offset)); } void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + std::byte* dis_ptr=this->data(); + size_t Elemsize_in_bytes=this->elementSize()*this->numel(); + llaisysMemcpyKind_t CurKind=this->deviceType()==LLAISYS_DEVICE_CPU?LLAISYS_MEMCPY_H2H:LLAISYS_MEMCPY_H2D; + core::context().runtime().api()->memcpy_sync(dis_ptr,src_,Elemsize_in_bytes,CurKind); } tensor_t Tensor::contiguous() const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + // 1. 如果已经是连续的,仿照 slice 的实现,返回一个指向相同存储的浅拷贝 + if (this->isContiguous()) { + // + return std::shared_ptr(new Tensor(this->_meta, this->_storage, this->_offset)); + } + + // 2. 如果不连续,创建一个形状相同、物理连续的新 Tensor + // + auto out = Tensor::create(this->shape(), this->dtype(), this->deviceType(), this->deviceId()); + + // 3. 【核心技巧】由于 rearrange 需要 tensor_t (shared_ptr) + // 我们将 const this 指针临时包装成一个不拥有所有权的 shared_ptr + // 加上 const_cast 是因为 rearrange 的输入参数类型要求 + tensor_t self_wrapper(const_cast(this), [](Tensor*){ + // 空删除器:防止这个临时 shared_ptr 析构时误删 this + }); + + // 4. 调用你定义的 rearrange 算子进行物理搬运 + // + llaisys::ops::rearrange(out, self_wrapper); + + return out; } tensor_t Tensor::reshape(const std::vector &shape) const { diff --git a/src/utils/blas_runtime.hpp b/src/utils/blas_runtime.hpp new file mode 100644 index 000000000..a00640b27 --- /dev/null +++ b/src/utils/blas_runtime.hpp @@ -0,0 +1,158 @@ +#pragma once +// ============================================================ +// 运行时 BLAS 检测与加载 (dlopen) +// +// 在程序首次调用 blas::sgemm() 时自动检测并加载系统上可用的 BLAS 库。 +// 优先级: MKL > OpenBLAS > 无 (返回 false, 调用方应 fallback 到手写内核) +// +// 使用方式: +// if (blas::available()) { +// blas::sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, +// M, N, K, 1.0f, A, K, B, K, 0.0f, C, N); +// } +// +// 此模块不引入编译期依赖, 同一份 .so 可在有/无 BLAS 的机器上运行。 +// ============================================================ + +#include +#include +#include + +#ifndef _WIN32 +#include +#endif + +namespace llaisys::blas { + +// CBLAS 枚举 (与标准 cblas.h 兼容) +enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 }; +enum CBLAS_TRANSPOSE { CblasNoTrans = 111, CblasTrans = 112, CblasConjTrans = 113 }; + +// cblas_sgemm 函数指针类型 +using sgemm_fn_t = void (*)( + CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, + int M, int N, int K, + float alpha, const float* A, int lda, + const float* B, int ldb, + float beta, float* C, int ldc); + +// 设置线程数的函数指针类型 +using set_num_threads_fn_t = void (*)(int); + +// 内部状态 (线程安全: 全局初始化一次) +namespace detail { + +enum class BlasBackend { NONE, MKL, OPENBLAS }; + +struct BlasState { + bool initialized = false; + bool is_available = false; + BlasBackend backend = BlasBackend::NONE; + void* handle = nullptr; + sgemm_fn_t sgemm_ptr = nullptr; + set_num_threads_fn_t set_threads_ptr = nullptr; +}; + +inline BlasState& state() { + static BlasState s; + return s; +} + +inline void try_load() { +#ifdef _WIN32 + state().initialized = true; + return; // Windows 不支持 dlopen +#else + auto& s = state(); + if (s.initialized) return; + s.initialized = true; + + // 尝试加载的库列表 (按优先级) + struct LibCandidate { + const char* path; + BlasBackend backend; + const char* sgemm_sym; + const char* threads_sym; + }; + + // MKL rt 是单一入口点, 会自动根据 CPU 选择最优内核 + LibCandidate candidates[] = { + // MKL (优先) + {"libmkl_rt.so.1", BlasBackend::MKL, "cblas_sgemm", "MKL_Set_Num_Threads"}, + {"libmkl_rt.so", BlasBackend::MKL, "cblas_sgemm", "MKL_Set_Num_Threads"}, + // OpenBLAS + {"libopenblas.so.0", BlasBackend::OPENBLAS, "cblas_sgemm", "openblas_set_num_threads"}, + {"libopenblas.so", BlasBackend::OPENBLAS, "cblas_sgemm", "openblas_set_num_threads"}, + }; + + // MKL 的默认 Intel 线程层 (libiomp5) 在某些服务器上存在 bug, + // 导致多线程 sgemm 产生错误结果 (累加被重复计算)。 + // 强制使用 GNU 线程层 (libgomp) 避免此问题。 + // 必须在 dlopen 之前设置, 因为 MKL 在加载时读取此环境变量。 + setenv("MKL_THREADING_LAYER", "GNU", 0); // 0 = 不覆盖用户已设置的值 + + for (auto& c : candidates) { + void* h = dlopen(c.path, RTLD_NOW | RTLD_LOCAL); + if (!h) continue; + + auto fn = (sgemm_fn_t)dlsym(h, c.sgemm_sym); + if (!fn) { + dlclose(h); + continue; + } + + s.handle = h; + s.sgemm_ptr = fn; + s.backend = c.backend; + s.is_available = true; + + // 尝试获取线程设置函数 (可选) + s.set_threads_ptr = (set_num_threads_fn_t)dlsym(h, c.threads_sym); + + const char* name = (c.backend == BlasBackend::MKL) ? "MKL" : "OpenBLAS"; + std::fprintf(stderr, "[llaisys] Runtime BLAS loaded: %s (%s)\n", name, c.path); + return; + } + + // 未找到任何 BLAS + std::fprintf(stderr, "[llaisys] No runtime BLAS found, using AVX2 fallback\n"); +#endif +} + +} // namespace detail + +// 检查是否有可用的 BLAS (首次调用时自动初始化) +inline bool available() { + if (!detail::state().initialized) + detail::try_load(); + return detail::state().is_available; +} + +// 获取后端名称 +inline const char* backend_name() { + if (!available()) return "none"; + switch (detail::state().backend) { + case detail::BlasBackend::MKL: return "MKL"; + case detail::BlasBackend::OPENBLAS: return "OpenBLAS"; + default: return "none"; + } +} + +// 设置 BLAS 内部线程数 +inline void set_num_threads(int n) { + if (available() && detail::state().set_threads_ptr) + detail::state().set_threads_ptr(n); +} + +// 调用 cblas_sgemm +inline void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, + int M, int N, int K, + float alpha, const float* A, int lda, + const float* B, int ldb, + float beta, float* C, int ldc) +{ + detail::state().sgemm_ptr(order, transA, transB, + M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); +} + +} // namespace llaisys::blas diff --git a/src/utils/simd.hpp b/src/utils/simd.hpp new file mode 100644 index 000000000..8b838d1d1 --- /dev/null +++ b/src/utils/simd.hpp @@ -0,0 +1,140 @@ +#pragma once +#include +#include + +namespace llaisys::utils { + +// ============================================================ +// AVX2 SIMD 辅助函数 +// 供各个 op 的 AVX2 特化共用 +// ============================================================ + +// AVX2 水平求和: 8 个 float → 1 个 float +inline float hsum256(__m256 v) { + __m128 hi = _mm256_extractf128_ps(v, 1); + __m128 lo = _mm256_castps256_ps128(v); + __m128 sum4 = _mm_add_ps(lo, hi); + sum4 = _mm_hadd_ps(sum4, sum4); + sum4 = _mm_hadd_ps(sum4, sum4); + float r; + _mm_store_ss(&r, sum4); + return r; +} + +// AVX2 float32 双累加器点积 +inline float avx2_dot(const float* __restrict__ a, + const float* __restrict__ b, + size_t len) +{ + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + size_t k = 0; + for (; k + 15 < len; k += 16) { + acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(a + k), _mm256_loadu_ps(b + k), acc0); + acc1 = _mm256_fmadd_ps(_mm256_loadu_ps(a + k + 8), _mm256_loadu_ps(b + k + 8), acc1); + } + for (; k + 7 < len; k += 8) { + acc0 = _mm256_fmadd_ps(_mm256_loadu_ps(a + k), _mm256_loadu_ps(b + k), acc0); + } + acc0 = _mm256_add_ps(acc0, acc1); + float dot = hsum256(acc0); + for (; k < len; ++k) { + dot += a[k] * b[k]; + } + return dot; +} + +// ============================================================ +// 批量类型转换 (8 元素 SIMD) +// ============================================================ + +// bf16 x8 → f32 x8 +// bf16 与 f32 共享指数格式, 仅需左移 16 位 +inline __m256 bf16x8_to_f32x8(const uint16_t* p) { + __m128i h8 = _mm_loadu_si128(reinterpret_cast(p)); + __m256i i32 = _mm256_cvtepu16_epi32(h8); + i32 = _mm256_slli_epi32(i32, 16); + return _mm256_castsi256_ps(i32); +} + +// fp16 x8 → f32 x8 +// 提取 sign/exp/mantissa, 重偏指数 (+112), 处理零值 +inline __m256 fp16x8_to_f32x8(const uint16_t* p) { + __m128i h8 = _mm_loadu_si128(reinterpret_cast(p)); + __m256i i32 = _mm256_cvtepu16_epi32(h8); + + __m256i sign = _mm256_slli_epi32( + _mm256_and_si256(i32, _mm256_set1_epi32(0x8000)), 16); + __m256i exp16 = _mm256_and_si256( + _mm256_srli_epi32(i32, 10), _mm256_set1_epi32(0x1F)); + __m256i mant = _mm256_and_si256(i32, _mm256_set1_epi32(0x3FF)); + + __m256i exp32 = _mm256_slli_epi32( + _mm256_add_epi32(exp16, _mm256_set1_epi32(112)), 23); + __m256i mant32 = _mm256_slli_epi32(mant, 13); + + __m256i result = _mm256_or_si256(sign, _mm256_or_si256(exp32, mant32)); + + // 处理零 (exp==0 且 mant==0) + __m256i is_zero = _mm256_cmpeq_epi32( + _mm256_and_si256(i32, _mm256_set1_epi32(0x7FFF)), + _mm256_setzero_si256()); + result = _mm256_andnot_si256(is_zero, result); + result = _mm256_or_si256(result, _mm256_and_si256(is_zero, sign)); + + return _mm256_castsi256_ps(result); +} + +// ============================================================ +// AVX-512 SIMD 辅助函数 (仅在支持 AVX-512 的编译器/目标上可用) +// ============================================================ +#ifdef __AVX512F__ + +// AVX-512 水平求和: 16 个 float → 1 个 float +inline float hsum512(__m512 v) { + // 256-bit 上下两半相加 + __m256 lo = _mm512_castps512_ps256(v); + __m256 hi = _mm512_extractf32x8_ps(v, 1); + __m256 sum8 = _mm256_add_ps(lo, hi); + return hsum256(sum8); +} + +// bf16 x16 → f32 x16 (AVX-512) +// 从 16 个 uint16_t 生成 16 个 f32 +inline __m512 bf16x16_to_f32x16(const uint16_t* p) { + __m256i h16 = _mm256_loadu_si256(reinterpret_cast(p)); + __m512i i32 = _mm512_cvtepu16_epi32(h16); + i32 = _mm512_slli_epi32(i32, 16); + return _mm512_castsi512_ps(i32); +} + +// fp16 x16 → f32 x16 (AVX-512) +inline __m512 fp16x16_to_f32x16(const uint16_t* p) { + __m256i h16 = _mm256_loadu_si256(reinterpret_cast(p)); + __m512i i32 = _mm512_cvtepu16_epi32(h16); + + __m512i sign = _mm512_slli_epi32( + _mm512_and_si512(i32, _mm512_set1_epi32(0x8000)), 16); + __m512i exp16 = _mm512_and_si512( + _mm512_srli_epi32(i32, 10), _mm512_set1_epi32(0x1F)); + __m512i mant = _mm512_and_si512(i32, _mm512_set1_epi32(0x3FF)); + + __m512i exp32 = _mm512_slli_epi32( + _mm512_add_epi32(exp16, _mm512_set1_epi32(112)), 23); + __m512i mant32 = _mm512_slli_epi32(mant, 13); + + __m512i result = _mm512_or_si512(sign, _mm512_or_si512(exp32, mant32)); + + // 处理零 (exp==0 且 mant==0): 使用 mask 操作 + __mmask16 is_zero = _mm512_cmpeq_epi32_mask( + _mm512_and_si512(i32, _mm512_set1_epi32(0x7FFF)), + _mm512_setzero_si512()); + result = _mm512_mask_mov_epi32(result, is_zero, sign); + + return _mm512_castsi512_ps(result); +} + +#endif // __AVX512F__ + +} // namespace llaisys::utils diff --git a/src/utils/types.cpp b/src/utils/types.cpp.old similarity index 99% rename from src/utils/types.cpp rename to src/utils/types.cpp.old index 4163c2148..adb52b198 100644 --- a/src/utils/types.cpp +++ b/src/utils/types.cpp.old @@ -1,5 +1,4 @@ #include "types.hpp" - #include namespace llaisys::utils { diff --git a/src/utils/types.hpp b/src/utils/types.hpp index e09619db8..3280a66ce 100644 --- a/src/utils/types.hpp +++ b/src/utils/types.hpp @@ -1,5 +1,6 @@ +#pragma once #include "llaisys.h" - +#include #include #include @@ -107,11 +108,85 @@ inline const char *dtype_to_str(llaisysDataType_t dtype) { } } -float _f16_to_f32(fp16_t val); -fp16_t _f32_to_f16(float val); +inline float _f16_to_f32(fp16_t val) { + uint16_t h = val._v; + uint32_t sign = (h & 0x8000) << 16; + int32_t exponent = (h >> 10) & 0x1F; + uint32_t mantissa = h & 0x3FF; + + uint32_t f32; + if (exponent == 31) { + if (mantissa != 0) { + f32 = sign | 0x7F800000 | (mantissa << 13); + } else { + f32 = sign | 0x7F800000; + } + } else if (exponent == 0) { + if (mantissa == 0) { + f32 = sign; + } else { + exponent = -14; + while ((mantissa & 0x400) == 0) { + mantissa <<= 1; + exponent--; + } + mantissa &= 0x3FF; + f32 = sign | ((exponent + 127) << 23) | (mantissa << 13); + } + } else { + f32 = sign | ((exponent + 127 - 15) << 23) | (mantissa << 13); + } + + float result; + memcpy(&result, &f32, sizeof(result)); + return result; +} + +inline fp16_t _f32_to_f16(float val) { + uint32_t f32; + memcpy(&f32, &val, sizeof(f32)); // Read the bits of the float32 + uint16_t sign = (f32 >> 16) & 0x8000; // Extract the sign bit + int32_t exponent = ((f32 >> 23) & 0xFF) - 127; // Extract and de-bias the exponent + uint32_t mantissa = f32 & 0x7FFFFF; // Extract the mantissa (fraction part) + + if (exponent >= 16) { // Special cases for Inf and NaN + // NaN + if (exponent == 128 && mantissa != 0) { + return fp16_t{static_cast(sign | 0x7E00)}; + } + // Infinity + return fp16_t{static_cast(sign | 0x7C00)}; + } else if (exponent >= -14) { // Normalized case + return fp16_t{(uint16_t)(sign | ((exponent + 15) << 10) | (mantissa >> 13))}; + } else if (exponent >= -24) { + mantissa |= 0x800000; // Add implicit leading 1 + mantissa >>= (-14 - exponent); + return fp16_t{(uint16_t)(sign | (mantissa >> 13))}; + } else { + // Too small for subnormal: return signed zero + return fp16_t{(uint16_t)sign}; + } +} + +inline float _bf16_to_f32(bf16_t val) { + uint32_t bits32 = static_cast(val._v) << 16; + + float out; + std::memcpy(&out, &bits32, sizeof(out)); + return out; +} -float _bf16_to_f32(bf16_t val); -bf16_t _f32_to_bf16(float val); +inline bf16_t _f32_to_bf16(float val) { + uint32_t bits32; + std::memcpy(&bits32, &val, sizeof(bits32)); + + const uint32_t rounding_bias = 0x00007FFF + // 0111 1111 1111 1111 + ((bits32 >> 16) & 1); + + uint16_t bf16_bits = static_cast((bits32 + rounding_bias) >> 16); + + return bf16_t{bf16_bits}; +} template TypeTo cast(TypeFrom val) { diff --git a/src/utils/utils_stub.cpp b/src/utils/utils_stub.cpp new file mode 100644 index 000000000..73549369b --- /dev/null +++ b/src/utils/utils_stub.cpp @@ -0,0 +1,6 @@ +// Placeholder so llaisys-utils target has at least one source file when types.cpp is not present. +namespace llaisys { +namespace utils_stub { +void placeholder() {} +} // namespace utils_stub +} // namespace llaisys diff --git a/test/debug.py b/test/debug.py new file mode 100644 index 000000000..0dde62296 --- /dev/null +++ b/test/debug.py @@ -0,0 +1,41 @@ +import argparse +from test_utils import * +import llaisys +import sys +import ctypes +from pathlib import Path + +# 不需要 snapshot_download 了,因为你已经下载好了 +# from huggingface_hub import snapshot_download + +def test_binding_only(): + print("--- Start Binding Test ---") + + # 1. 直接指定你刚才下载好的、确定的绝对路径 + # 注意:确保这个文件夹里真的有 .safetensors 文件 + real_model_path = "/home/cpp/ai-models/DeepSeek-R1-Distill-Qwen-1.5B" + + print(f"1. Using local model at: {real_model_path}") + + # 2. 检查一下路径对不对 (防御性编程) + if not Path(real_model_path).exists(): + print(f"!!! Error: Path does not exist: {real_model_path}") + return + + # 3. 尝试加载 C++ 模型 + try: + print("2. Calling C++ Qwen2 Init (Create + LoadWeights)...") + + # 直接传路径字符串! + model = llaisys.models.Qwen2(real_model_path, llaisys_device("cpu")) + + print("3. Success! C++ Object Created & Weights Loaded.") + print(f" Model Object: {model}") + + except Exception as e: + print(f"!!! Error Occurred: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + test_binding_only() \ No newline at end of file diff --git a/test/test_tensor.py b/test/test_tensor.py index 9d2e9a075..c701f58cf 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -1,5 +1,4 @@ import llaisys - import torch from test_utils import * import argparse diff --git a/xmake.lua b/xmake.lua index 1f65f7a95..48e59d33a 100644 --- a/xmake.lua +++ b/xmake.lua @@ -3,6 +3,41 @@ set_encodings("utf-8") add_includedirs("include") +-- 全局开启 OpenMP 支持、最高级别优化 +-- native 模式 (服务器): 使用 -march=native, 编译器自动启用 AVX-512 等本机指令集 +-- 默认模式 (本地): 显式指定 -mavx2 -mfma, 兼容大多数 x86-64 CPU +add_cxflags("-fopenmp", "-O3") +add_ldflags("-fopenmp") +add_shflags("-fopenmp") +add_syslinks("gomp") -- 显式链接 GNU OpenMP 库 + +option("native") + set_default(false) + set_showmenu(true) + set_description("Use -march=native for best performance on current CPU (enables AVX-512 on supported CPUs)") +option_end() + +if has_config("native") then + add_cxflags("-march=native") +else + add_cxflags("-mavx2", "-mfma") +end + +-- OpenBLAS 集成: 从源码编译安装到 ~/openblas +option("openblas") + set_default(true) + set_showmenu(true) + set_description("Whether to use OpenBLAS for linear algebra acceleration") +option_end() + +if has_config("openblas") then + add_defines("USE_OPENBLAS") + add_includedirs(os.getenv("HOME") .. "/openblas/include") + add_linkdirs(os.getenv("HOME") .. "/openblas/lib") + add_links("openblas") + add_rpathdirs(os.getenv("HOME") .. "/openblas/lib") +end + -- CPU -- includes("xmake/cpu.lua") @@ -106,6 +141,7 @@ target("llaisys") set_languages("cxx17") set_warnings("all", "error") add_files("src/llaisys/*.cc") + add_files("src/models/*.cpp") set_installdir(".")