diff --git a/Project_3.md b/Project_3.md
new file mode 100644
index 000000000..cbaafafcf
--- /dev/null
+++ b/Project_3.md
@@ -0,0 +1,265 @@
+# Project 3: AI Chatbot Implementation Report
+
+## Project Overview
+
+Project 3 builds a chatbot you can actually talk to on top of the LLAISYS inference framework. The work splits into three parts: a random-sampling operator, a backend server, and a frontend chat UI.
+
+---
+
+## 1. Random Sampling (the Sample Operator)
+
+In the first two assignments the model generated tokens with argmax — always picking the single highest-probability token. The output is deterministic but rigid: the same question always gets the same answer. Project 3 requires true random sampling so the model's replies feel more natural.
+
+### Implementation
+
+The sampling operator lives in `src/ops/sample/cpu/sample_cpu.cpp`; its core logic has six steps.
+
+**Step 1: Temperature scaling**
+
+The model's final layer outputs logits (raw scores), not probabilities. Before converting them to probabilities, every logit is divided by the `temperature` parameter:
+
+```cpp
+scores[i] = val / temperature;
+```
+
+A temperature below 1 widens the gap between high and low scores, so the model favors high-probability tokens (more conservative); above 1 the gap shrinks and choices become more random (more creative); at exactly 1 nothing changes.
+
+**Step 2: Sorting**
+
+Sort all tokens by score in descending order. We sort an index array rather than the scores themselves, which makes it easy to map back to the original token IDs later:
+
+```cpp
+std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
+    return scores[a] > scores[b];
+});
+```
+
+**Step 3: Top-K filtering**
+
+Keep only the K highest-scoring candidates and discard the rest. `top_k=1` degenerates to argmax; `top_k=0` disables the filter.
+
+**Step 4: Softmax**
+
+Apply softmax to the surviving candidates to turn scores into probabilities. For numerical stability, subtract the maximum score before exponentiating so the values cannot overflow:
+
+```cpp
+float max_score = scores[indices[0]];
+probs[i] = std::exp(scores[indices[i]] - max_score);
+```
+
+**Step 5: Top-P (nucleus sampling)**
+
+Accumulate probabilities from highest to lowest and cut off as soon as the cumulative probability exceeds the threshold `top_p`, then renormalize the remaining tokens. With `top_p=0.9`, for example, sampling happens only among the tokens whose probabilities together cover 90% of the mass; the low-probability tail is excluded.
+
+**Step 6: Multinomial sampling**
+
+Draw one token at random, weighted by probability, using `std::discrete_distribution`:
+
+```cpp
+std::discrete_distribution<size_t> dist(probs.begin(), probs.end());
+*sampled_token = static_cast<int64_t>(indices[dist(rng)]);
+```
+
+The random number generator is declared `thread_local`, so each thread gets its own instance and there is no cross-thread contention. Its seed mixes hardware entropy with the current time, which works around `random_device` possibly returning a fixed value on Windows.
+
+---
+
+## 2. Backend Server
+
+The server is built with FastAPI. The code lives under `test/server/`, split into `main.py` (startup entry point) and `routes.py` (route logic).
+
+### Architecture
+
+```
+Browser <──SSE──> FastAPI (routes.py) <──ctypes──> C++ inference backend
+                       │
+              HuggingFace Tokenizer (encoding/decoding)
+```
+
+The C++ model is exposed to Python through a ctypes wrapper; tokenization uses HuggingFace's `AutoTokenizer` directly.
+
+### Startup flow
+
+The `main()` function in `main.py` drives startup: it loads the C++ model and the tokenizer first, then starts the uvicorn HTTP server. Because the model is loaded before the server starts, the first request never has to wait for loading.
+
+```python
+model = llaisys.models.Qwen2(args.model, device=device)
+tokenizer = AutoTokenizer.from_pretrained(args.model)
+set_model(model, tokenizer)
+uvicorn.run(app, host=args.host, port=args.port)
+```
+
+### Session management
+
+The server keeps every session in a global dictionary `SESSIONS`; each session has a unique 8-character UUID, a title, and the full conversation history:
+
+```python
+SESSIONS: dict = {}  # session_id -> {id, title, history}
+```
+
+Four REST endpoints are exposed:
+- `GET /v1/sessions` — list all sessions
+- `POST /v1/sessions` — create a new session
+- `GET /v1/sessions/{id}` — fetch a session's full history
+- `DELETE /v1/sessions/{id}` — delete a session
+
+### Chat endpoint
+
+The core endpoint is `POST /v1/chat/completions`, compatible with the OpenAI API format and supporting both streaming and non-streaming modes.
+
+On every request the client sends the complete conversation history. The server formats it with the tokenizer's `apply_chat_template` into the prompt layout the model was trained on, then encodes it into a token-ID sequence for the C++ model:
+
+```python
+text = TOKENIZER.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+tokens = TOKENIZER.encode(text)
+```
+
+### Streaming responses (SSE)
+
+Streaming mode is the most interesting part of the server. The browser does not wait for generation to finish; it receives a piece of text the moment each token is generated, so the model appears to be "typing".
+
+Technically this uses the SSE (Server-Sent Events) protocol; each event looks like:
+
+```
+data: {"choices": [{"delta": {"content": "你好"}}]}\n\n
+```
+
+There is one catch: the C++ inference call is synchronous and blocking, while FastAPI runs on an asyncio event loop. Calling blocking code directly inside an async function would freeze the whole server. The fix is `run_in_executor`, which hands each inference call to a thread pool:
+
+```python
+next_token = await loop.run_in_executor(None, _infer_one)
+```
+
+This lets the event loop flush already-generated chunks to the client while waiting for the next inference result.
+
+Another detail is **incremental decoding**. A quirk of BPE tokenizers is that some tokens decode to mojibake on their own and only produce correct text when decoded together with their neighbors (Chinese characters, for instance, often span several tokens). So each step decodes all tokens generated so far and takes the newly appended part as the delta:
+
+```python
+raw = TOKENIZER.decode(generated, skip_special_tokens=False)
+clean = _clean_output(raw)
+delta = clean[prev_clean_len:]
+```
+
+### Filtering the DeepSeek chain of thought
+
+The DeepSeek-R1 model emits a reasoning trace wrapped in `<think>...</think>` before its actual answer; that part should not be shown to the user. The `_clean_output` function strips it with regular expressions:
+
+```python
+text = re.sub(r"<think>[\s\S]*?</think>", "", text)
+# Handle the case where <think> is part of the prompt template and only </think> appears in the output
+text = re.sub(r"^[\s\S]*?</think>", "", text)
+# Strip special tokens such as <|end_of_sentence|>
+text = re.sub(r"<[｜|][^｜|]*[｜|]>", "", text)
+```
+
+---
+
+## 3. Frontend Chat UI
+
+The frontend is plain HTML + CSS + JavaScript with no framework dependencies; the code lives under `test/server/static/`.
+
+### Page layout
+
+The page has two columns: the session list (sidebar) on the left and the chat window on the right.
+
+```
+┌─────────────┬──────────────────────────────┐
+│ Chats  [+]  │ LLAISYS Chatbot          [⚙] │
+│─────────────│──────────────────────────────│
+│ > Session 1 │                              │
+│   Session 2 │         Message area         │
+│   Session 3 │                              │
+│             │──────────────────────────────│
+│             │ [input box]           [Send] │
+└─────────────┴──────────────────────────────┘
+```
+
+The gear button in the top-right corner opens a settings panel for Temperature, Top-K, Top-P, and the maximum generation length.
+
+### Client-side state
+
+All session data lives in a `sessions` object, and `activeId` tracks which session is currently displayed:
+
+```javascript
+let sessions = {};   // { session_id: { id, title, history } }
+let activeId = null;
+```
+
+Switching sessions re-renders directly from the local `sessions` history without another server request, so switching is instant.
+
+### Receiving the stream
+
+After sending a message, the frontend reads the SSE stream with `fetch` + `ReadableStream`:
+
+```javascript
+const reader = resp.body.getReader();
+const decoder = new TextDecoder();
+let buf = '';
+
+while (true) {
+  const { done, value } = await reader.read();
+  if (done) break;
+  buf += decoder.decode(value, { stream: true });
+  // Parse SSE events line by line and append each delta to the bubble
+}
+```
+
+Each delta updates the bubble's text as it arrives, producing the typewriter effect.
+
+### Surviving page refreshes
+
+On page load the frontend pulls the existing session list and histories from the server, so refreshing the page does not lose the conversation:
+
+```javascript
+async function init() {
+  const existing = await fetchSessions();
+  for (const s of existing) {
+    const history = await fetchSessionHistory(s.id);
+    sessions[s.id] = { id: s.id, title: s.title, history };
+  }
+  // Restore the first session, or create a new one automatically
+}
+```
+
+---
+
+## 4. Inference Core (KV Cache)
+
+Model inference lives in `src/models/qwen2.cpp`. Each generated token goes through one call to `forward_token`, a standard Transformer decoder pass:
+
+```
+token_id → embedding → [×28 layers: RMSNorm → QKV Linear → RoPE → Attention → Linear → Add
+                        → RMSNorm → Gate/Up Linear → SwiGLU → Down Linear → Add]
+         → RMSNorm → LM Head Linear → logits → sample → next token
+```
+
+The KV cache is what makes inference speed usable. Without it, generating each new token would re-run attention over the entire history, so the per-token cost grows linearly with sequence length. With the cache, each layer's K and V matrices are stored once computed; the next step only computes the new token's Q and attends against the cached K and V:
+
+```cpp
+// Write the current token's k, v into the cache
+write_kv_cache(layer, pos, k_rope, v_view);
+
+// Slice out all k, v from position 0 through the current position
+auto k_total = _k_cache[layer]->slice(0, 0, pos + 1);
+auto v_total = _v_cache[layer]->slice(0, 0, pos + 1);
+
+// q covers only the current token, but k/v cover the full history
+llaisys::ops::self_attention(attn_val, q_rope, k_total, v_total, scale);
+```
+
+Each layer's KV cache is preallocated as a `(maxseq, nkvh, dh)` tensor, where `maxseq` is the maximum sequence length, so no memory is allocated dynamically at runtime.
+
+---
+
+## 5. Running It
+
+```bash
+# Install dependencies
+pip install fastapi uvicorn
+
+# Start the server
+python test/server/main.py --model /path/to/DeepSeek-R1-Distill-Qwen-1.5B --port 8000
+
+# Open in a browser
+http://localhost:8000
+```
diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h
index 7054626d4..1b83d1500 100644
--- a/include/llaisys/models/qwen2.h
+++ b/include/llaisys/models/qwen2.h
@@ -37,6 +37,6 @@ __C {
     __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model);
-    __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken);
+    __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken, float temperature, int top_k, float top_p);
 }
 #endif // LLAISYS_MODELS_QWEN2_H
diff --git a/include/llaisys/ops.h b/include/llaisys/ops.h
index ddb3be246..dcab888a2 100644
--- a/include/llaisys/ops.h
+++ b/include/llaisys/ops.h
@@ -11,6 +11,7 @@ __C {
     __export void llaisysRearrange(llaisysTensor_t out, llaisysTensor_t in);
     __export void llaisysRmsNorm(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, float eps);
     __export void llaisysROPE(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t pos_ids, float theta);
+    __export void llaisysSample(llaisysTensor_t sampled_token, llaisysTensor_t logits, float temperature, int top_k, float top_p);
     __export void llaisysSelfAttention(llaisysTensor_t attn_val, llaisysTensor_t q, llaisysTensor_t k, llaisysTensor_t v, float scale);
     __export void llaisysSwiGLU(llaisysTensor_t out, llaisysTensor_t gate, llaisysTensor_t up);
 }
diff --git a/python/llaisys/libllaisys/__init__.py
b/python/llaisys/libllaisys/__init__.py index f536fb527..f2d32ea7f 100644 --- a/python/llaisys/libllaisys/__init__.py +++ b/python/llaisys/libllaisys/__init__.py @@ -12,6 +12,7 @@ from .tensor import llaisysTensor_t from .tensor import load_tensor from .ops import load_ops +from .models import load_qwen2 def load_shared_library(): @@ -38,6 +39,7 @@ def load_shared_library(): load_runtime(LIB_LLAISYS) load_tensor(LIB_LLAISYS) load_ops(LIB_LLAISYS) +load_qwen2(LIB_LLAISYS) __all__ = [ diff --git a/python/llaisys/libllaisys/models/__init__.py b/python/llaisys/libllaisys/models/__init__.py new file mode 100644 index 000000000..e33c5e728 --- /dev/null +++ b/python/llaisys/libllaisys/models/__init__.py @@ -0,0 +1,6 @@ +from .qwen2 import ( + LlaisysQwen2Meta, + LlaisysQwen2Weights, + LlaisysQwen2Model, + load_qwen2, +) diff --git a/python/llaisys/libllaisys/models/qwen2.py b/python/llaisys/libllaisys/models/qwen2.py new file mode 100644 index 000000000..8ac53c4cc --- /dev/null +++ b/python/llaisys/libllaisys/models/qwen2.py @@ -0,0 +1,59 @@ +import ctypes +from ctypes import POINTER, c_size_t, c_int, c_int64, c_float, c_void_p +from ..llaisys_types import llaisysDataType_t, llaisysDeviceType_t +from ..tensor import llaisysTensor_t + +class LlaisysQwen2Meta(ctypes.Structure): + _fields_ = [ + ("dtype", llaisysDataType_t), + ("nlayer", c_size_t), + ("hs", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("maxseq", c_size_t), + ("voc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_int64), + ] + +class LlaisysQwen2Weights(ctypes.Structure): + _fields_ = [ + ("in_embed", llaisysTensor_t), + ("out_embed", llaisysTensor_t), + ("out_norm_w", llaisysTensor_t), + ("attn_norm_w", POINTER(llaisysTensor_t)), + ("attn_q_w", POINTER(llaisysTensor_t)), + ("attn_q_b", POINTER(llaisysTensor_t)), + ("attn_k_w", POINTER(llaisysTensor_t)), + ("attn_k_b", POINTER(llaisysTensor_t)), + ("attn_v_w", 
POINTER(llaisysTensor_t)), + ("attn_v_b", POINTER(llaisysTensor_t)), + ("attn_o_w", POINTER(llaisysTensor_t)), + ("mlp_norm_w", POINTER(llaisysTensor_t)), + ("mlp_gate_w", POINTER(llaisysTensor_t)), + ("mlp_up_w", POINTER(llaisysTensor_t)), + ("mlp_down_w", POINTER(llaisysTensor_t)), + ] + +LlaisysQwen2Model = c_void_p + +def load_qwen2(lib): + lib.llaisysQwen2ModelCreate.argtypes = [ + POINTER(LlaisysQwen2Meta), + llaisysDeviceType_t, + POINTER(c_int), + c_int, + ] + lib.llaisysQwen2ModelCreate.restype = LlaisysQwen2Model + + lib.llaisysQwen2ModelDestroy.argtypes = [LlaisysQwen2Model] + lib.llaisysQwen2ModelDestroy.restype = None + + lib.llaisysQwen2ModelWeights.argtypes = [LlaisysQwen2Model] + lib.llaisysQwen2ModelWeights.restype = POINTER(LlaisysQwen2Weights) + + lib.llaisysQwen2ModelInfer.argtypes = [LlaisysQwen2Model, POINTER(c_int64), c_size_t, c_float, c_int, c_float] + lib.llaisysQwen2ModelInfer.restype = c_int64 diff --git a/python/llaisys/libllaisys/ops.py b/python/llaisys/libllaisys/ops.py index 5be095eff..77a10f38e 100644 --- a/python/llaisys/libllaisys/ops.py +++ b/python/llaisys/libllaisys/ops.py @@ -1,5 +1,5 @@ from .tensor import llaisysTensor_t -from ctypes import c_float +from ctypes import c_float, c_int def load_ops(lib): lib.llaisysAdd.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t] @@ -23,6 +23,9 @@ def load_ops(lib): lib.llaisysROPE.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t, c_float] lib.llaisysROPE.restype = None + lib.llaisysSample.argtypes = [llaisysTensor_t, llaisysTensor_t, c_float, c_int, c_float] + lib.llaisysSample.restype = None + lib.llaisysSelfAttention.argtypes = [ llaisysTensor_t, # attn_val llaisysTensor_t, # q diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b21..b7c8143d7 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -1,23 +1,189 @@ from typing import Sequence -from ..libllaisys import LIB_LLAISYS -from ..libllaisys 
import DeviceType - from pathlib import Path +import json +import ctypes +import numpy as np import safetensors +from ..libllaisys import LIB_LLAISYS +from ..libllaisys import DeviceType, DataType +from ..libllaisys import llaisysDeviceType_t, llaisysDataType_t +from ..libllaisys.models import LlaisysQwen2Meta, LlaisysQwen2Weights + class Qwen2: def __init__(self, model_path, device: DeviceType = DeviceType.CPU): - # TODO: Implement model constructor + print(f"Loading Qwen2 model from {model_path} on device {device.name}") + self._device = device model_path = Path(model_path) + config_path = model_path / "config.json" + if not config_path.exists(): + raise FileNotFoundError(f"config.json not found under {model_path}") + + with config_path.open("r", encoding="utf-8") as f: + config = json.load(f) + + torch_dtype = str(config.get("torch_dtype", "bfloat16")).lower() + if "bfloat" in torch_dtype: + dtype = DataType.BF16 + elif "float16" in torch_dtype or "fp16" in torch_dtype: + dtype = DataType.F16 + else: + dtype = DataType.F32 + + nlayer = int(config["num_hidden_layers"]) + hs = int(config["hidden_size"]) + nh = int(config["num_attention_heads"]) + nkvh = int(config.get("num_key_value_heads", nh)) + dh = int(hs // nh) + di = int(config["intermediate_size"]) + maxseq = int(config.get("max_position_embeddings", config.get("max_seq_len", 2048))) + voc = int(config["vocab_size"]) + epsilon = float(config.get("rms_norm_eps", 1e-6)) + theta = float(config.get("rope_theta", 10000.0)) + + eos_token = config.get("eos_token_id", None) + if isinstance(eos_token, list): + end_token = int(eos_token[0]) if eos_token else -1 + elif eos_token is None: + end_token = -1 + else: + end_token = int(eos_token) + + meta = LlaisysQwen2Meta( + llaisysDataType_t(dtype), + nlayer, + hs, + nh, + nkvh, + dh, + di, + maxseq, + voc, + epsilon, + theta, + end_token, + ) + + device_ids = (ctypes.c_int * 1)(0) + self._model = LIB_LLAISYS.llaisysQwen2ModelCreate( + ctypes.byref(meta), + 
llaisysDeviceType_t(device), + device_ids, + ctypes.c_int(1), + ) + self._weights_ptr = LIB_LLAISYS.llaisysQwen2ModelWeights(self._model) + self._weights: LlaisysQwen2Weights = self._weights_ptr.contents + self._dtype = dtype + self._end_token = end_token + + loaded = set() + use_torch_fallback = False + if self._dtype == DataType.BF16: + try: + np.dtype("bfloat16") + except TypeError: + use_torch_fallback = True for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") - for name_ in data_.keys(): - ## TODO: load the model weights - pass + if use_torch_fallback: + import torch + + data_ = safetensors.safe_open(file, framework="pt", device="cpu") + for name_ in data_.keys(): + tensor = self._resolve_weight(name_) + if tensor is None: + continue + t = data_.get_tensor(name_) + if t.dtype != torch.bfloat16: + t = t.to(torch.bfloat16) + raw = t.view(torch.uint16).contiguous().cpu().numpy() + LIB_LLAISYS.tensorLoad(tensor, ctypes.c_void_p(raw.ctypes.data)) + loaded.add(name_) + else: + data_ = safetensors.safe_open(file, framework="numpy", device="cpu") + for name_ in data_.keys(): + tensor = self._resolve_weight(name_) + if tensor is None: + continue + arr = data_.get_tensor(name_) + self._load_tensor(tensor, arr) + loaded.add(name_) + + if "lm_head.weight" not in loaded: + self._weights_ptr.contents.out_embed = self._weights_ptr.contents.in_embed + + def __del__(self): + if hasattr(self, "_model") and self._model is not None: + LIB_LLAISYS.llaisysQwen2ModelDestroy(self._model) + self._model = None + + def _np_dtype(self): + if self._dtype == DataType.BF16: + return np.dtype("bfloat16") + if self._dtype == DataType.F16: + return np.float16 + if self._dtype == DataType.F32: + return np.float32 + raise ValueError(f"Unsupported dtype: {self._dtype}") + + def _load_tensor(self, tensor, arr): + target_dtype = self._np_dtype() + if arr.dtype != target_dtype: + arr = arr.astype(target_dtype) + arr = 
np.ascontiguousarray(arr) + LIB_LLAISYS.tensorLoad(tensor, ctypes.c_void_p(arr.ctypes.data)) + + def _resolve_weight(self, name: str): + if name == "model.embed_tokens.weight": + return self._weights.in_embed + if name == "lm_head.weight": + return self._weights.out_embed + if name == "model.norm.weight": + return self._weights.out_norm_w + + if name.startswith("model.layers."): + parts = name.split(".") + if len(parts) < 5: + return None + try: + layer = int(parts[2]) + except ValueError: + return None + + block = parts[3] + if block == "input_layernorm" and parts[-1] == "weight": + return self._weights.attn_norm_w[layer] + if block in ("self_attn", "self_attention"): + if len(parts) < 6: + return None + proj = parts[4] + param = parts[5] + if proj == "q_proj": + return self._weights.attn_q_w[layer] if param == "weight" else self._weights.attn_q_b[layer] + if proj == "k_proj": + return self._weights.attn_k_w[layer] if param == "weight" else self._weights.attn_k_b[layer] + if proj == "v_proj": + return self._weights.attn_v_w[layer] if param == "weight" else self._weights.attn_v_b[layer] + if proj == "o_proj" and param == "weight": + return self._weights.attn_o_w[layer] + if block == "post_attention_layernorm" and parts[-1] == "weight": + return self._weights.mlp_norm_w[layer] + if block == "mlp": + if len(parts) < 6: + return None + proj = parts[4] + param = parts[5] + if proj == "gate_proj" and param == "weight": + return self._weights.mlp_gate_w[layer] + if proj == "up_proj" and param == "weight": + return self._weights.mlp_up_w[layer] + if proj == "down_proj" and param == "weight": + return self._weights.mlp_down_w[layer] + + return None def generate( self, @@ -27,7 +193,22 @@ def generate( top_p: float = 0.8, temperature: float = 0.8, ): + if max_new_tokens is None: + max_new_tokens = 128 - # TODO: Implement generate function + tokens = list(inputs) + for _ in range(max_new_tokens): + arr = (ctypes.c_int64 * len(tokens))(*tokens) + next_token = 
LIB_LLAISYS.llaisysQwen2ModelInfer( + self._model, + arr, + ctypes.c_size_t(len(tokens)), + ctypes.c_float(temperature), + ctypes.c_int(top_k), + ctypes.c_float(top_p), + ) + tokens.append(int(next_token)) + if self._end_token >= 0 and next_token == self._end_token: + break - return [] + return tokens diff --git a/python/llaisys/ops.py b/python/llaisys/ops.py index ed0180bc8..db25cd6ef 100644 --- a/python/llaisys/ops.py +++ b/python/llaisys/ops.py @@ -40,6 +40,16 @@ def rope(out: Tensor, inp: Tensor, pos_ids: Tensor, theta: float): out.lib_tensor(), inp.lib_tensor(), pos_ids.lib_tensor(), c_float(theta) ) + @staticmethod + def sample(sampled_token: Tensor, logits: Tensor, temperature: float, top_k: int, top_p: float): + LIB_LLAISYS.llaisysSample( + sampled_token.lib_tensor(), + logits.lib_tensor(), + c_float(temperature), + c_int(top_k), + c_float(top_p) + ) + @staticmethod def self_attention(attn_val: Tensor, q: Tensor, k: Tensor, v: Tensor, scale: float): LIB_LLAISYS.llaisysSelfAttention( diff --git a/src/llaisys/models/qwen2.cc b/src/llaisys/models/qwen2.cc new file mode 100644 index 000000000..35f0fff4b --- /dev/null +++ b/src/llaisys/models/qwen2.cc @@ -0,0 +1,52 @@ +#include "llaisys/models/qwen2.h" +#include "../../models/qwen2.hpp" +#include "../../utils.hpp" + +__C { + struct LlaisysQwen2Model { + llaisys::models::Qwen2Model *model; + }; + + struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice) { + if (meta == nullptr) { + std::cerr << "[ERROR] Qwen2ModelCreate: meta is null" << std::endl; + return nullptr; + } + try { + auto *model = new llaisys::models::Qwen2Model(*meta, device, device_ids, ndevice); + return new LlaisysQwen2Model{model}; + } catch (const std::exception &e) { + std::cerr << "[ERROR] Qwen2ModelCreate: " << e.what() << std::endl; + return nullptr; + } + } + + void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model *model) { + if (!model) { + return; + } + 
delete model->model; + delete model; + } + + struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model *model) { + if (model == nullptr) { + std::cerr << "[ERROR] Qwen2ModelWeights: model is null" << std::endl; + return nullptr; + } + return model->model->weights(); + } + + int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken, float temperature, int top_k, float top_p) { + if (model == nullptr) { + std::cerr << "[ERROR] Qwen2ModelInfer: model is null" << std::endl; + return -1; + } + try { + return model->model->infer(token_ids, ntoken, temperature, top_k, top_p); + } catch (const std::exception &e) { + std::cerr << "[ERROR] Qwen2ModelInfer: " << e.what() << std::endl; + return -1; + } + } +} diff --git a/src/llaisys/ops.cc b/src/llaisys/ops.cc index c99fbc32f..0938a54cc 100644 --- a/src/llaisys/ops.cc +++ b/src/llaisys/ops.cc @@ -9,6 +9,7 @@ #include "../ops/rearrange/op.hpp" #include "../ops/rms_norm/op.hpp" #include "../ops/rope/op.hpp" +#include "../ops/sample/op.hpp" #include "../ops/self_attention/op.hpp" #include "../ops/swiglu/op.hpp" @@ -34,6 +35,9 @@ __C { void llaisysROPE(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t pos_ids, float theta) { llaisys::ops::rope(out->tensor, in->tensor, pos_ids->tensor, theta); } + void llaisysSample(llaisysTensor_t sampled_token, llaisysTensor_t logits, float temperature, int top_k, float top_p) { + llaisys::ops::sample(sampled_token->tensor, logits->tensor, temperature, top_k, top_p); + } void llaisysSelfAttention(llaisysTensor_t attn_val, llaisysTensor_t q, llaisysTensor_t k, llaisysTensor_t v, float scale) { llaisys::ops::self_attention(attn_val->tensor, q->tensor, k->tensor, v->tensor, scale); } diff --git a/src/models/qwen2.cpp b/src/models/qwen2.cpp new file mode 100644 index 000000000..dbd217e12 --- /dev/null +++ b/src/models/qwen2.cpp @@ -0,0 +1,277 @@ +#include "qwen2.hpp" + +#include "../ops/add/op.hpp" +#include "../ops/argmax/op.hpp" 
+#include "../ops/embedding/op.hpp"
+#include "../ops/linear/op.hpp"
+#include "../ops/rms_norm/op.hpp"
+#include "../ops/rope/op.hpp"
+#include "../ops/self_attention/op.hpp"
+#include "../ops/swiglu/op.hpp"
+#include "../ops/sample/op.hpp"
+#include "../utils.hpp"
+
+#include <cmath>
+#include <cstring>
+
+namespace llaisys::models {
+
+Qwen2Model::Qwen2Model(const LlaisysQwen2Meta &meta, llaisysDeviceType_t device, int *device_ids, int ndevice)
+    : _meta(meta), _device(device), _device_id(0), _past_len(0) {
+    CHECK_ARGUMENT(device == LLAISYS_DEVICE_CPU, "Qwen2Model: only CPU is supported in this implementation");
+    if (device_ids != nullptr && ndevice > 0) {
+        _device_id = device_ids[0];
+    }
+    CHECK_ARGUMENT(_meta.nlayer > 0, "Qwen2Model: nlayer must be > 0");
+    CHECK_ARGUMENT(_meta.hs > 0, "Qwen2Model: hidden size must be > 0");
+    CHECK_ARGUMENT(_meta.nh > 0, "Qwen2Model: nhead must be > 0");
+    CHECK_ARGUMENT(_meta.dh > 0, "Qwen2Model: head dim must be > 0");
+    CHECK_ARGUMENT(_meta.hs == _meta.nh * _meta.dh, "Qwen2Model: hidden size must equal nhead * head_dim");
+
+    _weights.in_embed = create_weight_tensor({_meta.voc, _meta.hs});
+    _weights.out_embed = create_weight_tensor({_meta.voc, _meta.hs});
+    _weights.out_norm_w = create_weight_tensor({_meta.hs});
+
+    _weights.attn_norm_w = new llaisysTensor_t[_meta.nlayer];
+    _weights.attn_q_w = new llaisysTensor_t[_meta.nlayer];
+    _weights.attn_q_b = new llaisysTensor_t[_meta.nlayer];
+    _weights.attn_k_w = new llaisysTensor_t[_meta.nlayer];
+    _weights.attn_k_b = new llaisysTensor_t[_meta.nlayer];
+    _weights.attn_v_w = new llaisysTensor_t[_meta.nlayer];
+    _weights.attn_v_b = new llaisysTensor_t[_meta.nlayer];
+    _weights.attn_o_w = new llaisysTensor_t[_meta.nlayer];
+    _weights.mlp_norm_w = new llaisysTensor_t[_meta.nlayer];
+    _weights.mlp_gate_w = new llaisysTensor_t[_meta.nlayer];
+    _weights.mlp_up_w = new llaisysTensor_t[_meta.nlayer];
+    _weights.mlp_down_w = new llaisysTensor_t[_meta.nlayer];
+
+    const size_t qkv_dim = _meta.nh *
_meta.dh; + const size_t kv_dim = _meta.nkvh * _meta.dh; + + for (size_t i = 0; i < _meta.nlayer; i++) { + _weights.attn_norm_w[i] = create_weight_tensor({_meta.hs}); + _weights.attn_q_w[i] = create_weight_tensor({qkv_dim, _meta.hs}); + _weights.attn_q_b[i] = create_weight_tensor({qkv_dim}); + _weights.attn_k_w[i] = create_weight_tensor({kv_dim, _meta.hs}); + _weights.attn_k_b[i] = create_weight_tensor({kv_dim}); + _weights.attn_v_w[i] = create_weight_tensor({kv_dim, _meta.hs}); + _weights.attn_v_b[i] = create_weight_tensor({kv_dim}); + _weights.attn_o_w[i] = create_weight_tensor({_meta.hs, qkv_dim}); + _weights.mlp_norm_w[i] = create_weight_tensor({_meta.hs}); + _weights.mlp_gate_w[i] = create_weight_tensor({_meta.di, _meta.hs}); + _weights.mlp_up_w[i] = create_weight_tensor({_meta.di, _meta.hs}); + _weights.mlp_down_w[i] = create_weight_tensor({_meta.hs, _meta.di}); + + zero_tensor(_weights.attn_q_b[i]); + zero_tensor(_weights.attn_k_b[i]); + zero_tensor(_weights.attn_v_b[i]); + } + + _zero_bias_hs = llaisys::Tensor::create({_meta.hs}, _meta.dtype, _device, _device_id); + _zero_bias_di = llaisys::Tensor::create({_meta.di}, _meta.dtype, _device, _device_id); + _zero_bias_voc = llaisys::Tensor::create({_meta.voc}, _meta.dtype, _device, _device_id); + + zero_tensor_data(_zero_bias_hs); + zero_tensor_data(_zero_bias_di); + zero_tensor_data(_zero_bias_voc); + + _k_cache.resize(_meta.nlayer); + _v_cache.resize(_meta.nlayer); + for (size_t i = 0; i < _meta.nlayer; i++) { + _k_cache[i] = llaisys::Tensor::create({_meta.maxseq, _meta.nkvh, _meta.dh}, _meta.dtype, _device, _device_id); + _v_cache[i] = llaisys::Tensor::create({_meta.maxseq, _meta.nkvh, _meta.dh}, _meta.dtype, _device, _device_id); + } +} + +Qwen2Model::~Qwen2Model() { + delete[] _weights.attn_norm_w; + delete[] _weights.attn_q_w; + delete[] _weights.attn_q_b; + delete[] _weights.attn_k_w; + delete[] _weights.attn_k_b; + delete[] _weights.attn_v_w; + delete[] _weights.attn_v_b; + delete[] _weights.attn_o_w; + 
delete[] _weights.mlp_norm_w; + delete[] _weights.mlp_gate_w; + delete[] _weights.mlp_up_w; + delete[] _weights.mlp_down_w; + + for (auto *tensor : _owned_tensors) { + delete tensor; + } +} + +LlaisysQwen2Weights *Qwen2Model::weights() { + return &_weights; +} + +llaisysTensor_t Qwen2Model::create_weight_tensor(const std::vector &shape) { + auto tensor = llaisys::Tensor::create(shape, _meta.dtype, _device, _device_id); + auto *wrapped = new LlaisysTensor{tensor}; + _owned_tensors.push_back(wrapped); + return wrapped; +} + +llaisys::tensor_t Qwen2Model::unwrap(llaisysTensor_t tensor) const { + return tensor->tensor; +} + +void Qwen2Model::zero_tensor(llaisysTensor_t tensor) { + zero_tensor_data(tensor->tensor); +} + +void Qwen2Model::zero_tensor_data(const llaisys::tensor_t &tensor) { + CHECK_ARGUMENT(tensor->deviceType() == LLAISYS_DEVICE_CPU, "Qwen2Model: only CPU is supported for zeroing"); + const size_t bytes = tensor->numel() * tensor->elementSize(); + if (bytes == 0) { + return; + } + std::memset(tensor->data(), 0, bytes); +} + +void Qwen2Model::write_kv_cache(size_t layer, size_t pos, const llaisys::tensor_t &k, const llaisys::tensor_t &v) { + CHECK_ARGUMENT(layer < _meta.nlayer, "Qwen2Model: layer out of range"); + CHECK_ARGUMENT(pos < _meta.maxseq, "Qwen2Model: position exceeds max sequence length"); + + const size_t elem_bytes = k->elementSize(); + const size_t row_elems = _meta.nkvh * _meta.dh; + const size_t row_bytes = row_elems * elem_bytes; + + std::byte *k_dst = _k_cache[layer]->data() + pos * row_bytes; + std::byte *v_dst = _v_cache[layer]->data() + pos * row_bytes; + + std::memcpy(k_dst, k->data(), row_bytes); + std::memcpy(v_dst, v->data(), row_bytes); +} + +void Qwen2Model::reset_state() { + _past_len = 0; +} + +llaisys::tensor_t Qwen2Model::forward_token(int64_t token_id, size_t pos) { + CHECK_ARGUMENT(pos < _meta.maxseq, "Qwen2Model: position exceeds max sequence length"); + + auto token_tensor = llaisys::Tensor::create({1}, LLAISYS_DTYPE_I64, 
_device, _device_id); + token_tensor->load(&token_id); + + // x->(1, hs) + auto x = llaisys::Tensor::create({1, _meta.hs}, _meta.dtype, _device, _device_id); + llaisys::ops::embedding(x, token_tensor, unwrap(_weights.in_embed)); + + auto pos_ids = llaisys::Tensor::create({1}, LLAISYS_DTYPE_I64, _device, _device_id); + int64_t pos_id = static_cast(pos); + pos_ids->load(&pos_id); + + const float scale = 1.0f / std::sqrt(static_cast(_meta.dh)); + + for (size_t layer = 0; layer < _meta.nlayer; layer++) { + // x_norm->(1, hs) + auto x_norm = llaisys::Tensor::create({1, _meta.hs}, _meta.dtype, _device, _device_id); + llaisys::ops::rms_norm(x_norm, x, unwrap(_weights.attn_norm_w[layer]), _meta.epsilon); + + // q->(1, nh*dh), k->(1, nkvh*dh), v->(1, nkvh*dh) + // for qwen2 q->(1, 1536), k,v->(1, 256) + auto q = llaisys::Tensor::create({1, _meta.nh * _meta.dh}, _meta.dtype, _device, _device_id); + auto k = llaisys::Tensor::create({1, _meta.nkvh * _meta.dh}, _meta.dtype, _device, _device_id); + auto v = llaisys::Tensor::create({1, _meta.nkvh * _meta.dh}, _meta.dtype, _device, _device_id); + + // Wq->(1536, 1536), Wk->(256, 1536), Wv->(256, 1536) + // x_norm->(1, 1536), q = x_norm * Wq^T + bq + // q->(1, 1536), k,v->(1, 256) + llaisys::ops::linear(q, x_norm, unwrap(_weights.attn_q_w[layer]), unwrap(_weights.attn_q_b[layer])); + llaisys::ops::linear(k, x_norm, unwrap(_weights.attn_k_w[layer]), unwrap(_weights.attn_k_b[layer])); + llaisys::ops::linear(v, x_norm, unwrap(_weights.attn_v_w[layer]), unwrap(_weights.attn_v_b[layer])); + + // q_view->(1, 12, 128), k_view,v_view->(1, 2, 128) + auto q_view = q->view({1, _meta.nh, _meta.dh}); + auto k_view = k->view({1, _meta.nkvh, _meta.dh}); + auto v_view = v->view({1, _meta.nkvh, _meta.dh}); + + auto q_rope = llaisys::Tensor::create({1, _meta.nh, _meta.dh}, _meta.dtype, _device, _device_id); + auto k_rope = llaisys::Tensor::create({1, _meta.nkvh, _meta.dh}, _meta.dtype, _device, _device_id); + + // q_rope->(1, 12, 128), k_rope->(1, 
2, 128) + llaisys::ops::rope(q_rope, q_view, pos_ids, _meta.theta); + llaisys::ops::rope(k_rope, k_view, pos_ids, _meta.theta); + + // kv_cache[layer]->(maxseq, nkvh, dh) = (131072, 2, 128) + write_kv_cache(layer, pos, k_rope, v_view); + + // k_total->(n, 2, 128), v_total->(n, 2, 128) + auto k_total = _k_cache[layer]->slice(0, 0, pos + 1); + auto v_total = _v_cache[layer]->slice(0, 0, pos + 1); + + // attn_val->(1, nh, dh) = (1, 12, 128) + auto attn_val = llaisys::Tensor::create({1, _meta.nh, _meta.dh}, _meta.dtype, _device, _device_id); + // attn_val = q @ k^T / sqrt(dh) @v = (1, 12, 128) @ (n, 2, 128)^T @ (n, 2, 128) = (1, 12, 128) + llaisys::ops::self_attention(attn_val, q_rope, k_total, v_total, scale); + + // attn_val_2d->(1, 1536) + auto attn_val_2d = attn_val->view({1, _meta.hs}); + // attn_out = (1, 1536) attn_o_w->(1536, 1536) + auto attn_out = llaisys::Tensor::create({1, _meta.hs}, _meta.dtype, _device, _device_id); + llaisys::ops::linear(attn_out, attn_val_2d, unwrap(_weights.attn_o_w[layer]), _zero_bias_hs); + + // x = x + attn_out ->(1, 1536) + auto x_attn = llaisys::Tensor::create({1, _meta.hs}, _meta.dtype, _device, _device_id); + llaisys::ops::add(x_attn, x, attn_out); + x = x_attn; + + // mlp_norm_w->(1, 1536) + auto mlp_norm = llaisys::Tensor::create({1, _meta.hs}, _meta.dtype, _device, _device_id); + llaisys::ops::rms_norm(mlp_norm, x, unwrap(_weights.mlp_norm_w[layer]), _meta.epsilon); + + // mlp_gate_w->(8960, 1536), mlp_up_w->(8960, 1536) + auto gate = llaisys::Tensor::create({1, _meta.di}, _meta.dtype, _device, _device_id); + auto up = llaisys::Tensor::create({1, _meta.di}, _meta.dtype, _device, _device_id); + // gate, up->(1, 8960) + llaisys::ops::linear(gate, mlp_norm, unwrap(_weights.mlp_gate_w[layer]), _zero_bias_di); + llaisys::ops::linear(up, mlp_norm, unwrap(_weights.mlp_up_w[layer]), _zero_bias_di); + + // swiglu_out->(1, 8960) + auto swiglu_out = llaisys::Tensor::create({1, _meta.di}, _meta.dtype, _device, _device_id); + 
llaisys::ops::swiglu(swiglu_out, gate, up); + + // mlp_down_w->(1536, 8960) + auto mlp_out = llaisys::Tensor::create({1, _meta.hs}, _meta.dtype, _device, _device_id); + // mlp_out->(1, 1536) + llaisys::ops::linear(mlp_out, swiglu_out, unwrap(_weights.mlp_down_w[layer]), _zero_bias_hs); + + auto x_mlp = llaisys::Tensor::create({1, _meta.hs}, _meta.dtype, _device, _device_id); + llaisys::ops::add(x_mlp, x, mlp_out); + x = x_mlp; + } + + auto out_norm = llaisys::Tensor::create({1, _meta.hs}, _meta.dtype, _device, _device_id); + llaisys::ops::rms_norm(out_norm, x, unwrap(_weights.out_norm_w), _meta.epsilon); + + auto logits_2d = llaisys::Tensor::create({1, _meta.voc}, _meta.dtype, _device, _device_id); + llaisys::ops::linear(logits_2d, out_norm, unwrap(_weights.out_embed), _zero_bias_voc); + + return logits_2d->view({_meta.voc}); +} + +int64_t Qwen2Model::infer(const int64_t *token_ids, size_t ntoken, float temperature, int top_k, float top_p) { + CHECK_ARGUMENT(token_ids != nullptr, "Qwen2Model: token_ids is null"); + CHECK_ARGUMENT(ntoken > 0, "Qwen2Model: ntoken must be > 0"); + CHECK_ARGUMENT(ntoken <= _meta.maxseq, "Qwen2Model: ntoken exceeds max sequence length"); + + if (ntoken <= _past_len) { + reset_state(); + } + + llaisys::tensor_t logits; + for (size_t i = _past_len; i < ntoken; i++) { + logits = forward_token(token_ids[i], i); + } + + _past_len = ntoken; + + auto sampled = llaisys::Tensor::create({1}, LLAISYS_DTYPE_I64, _device, _device_id); + llaisys::ops::sample(sampled, logits, temperature, top_k, top_p); + + return *reinterpret_cast(sampled->data()); +} + +} // namespace llaisys::models diff --git a/src/models/qwen2.hpp b/src/models/qwen2.hpp new file mode 100644 index 000000000..d5adcf3a6 --- /dev/null +++ b/src/models/qwen2.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include "llaisys/models/qwen2.h" + +#include "../llaisys/llaisys_tensor.hpp" +#include "../tensor/tensor.hpp" + +#include + +namespace llaisys::models { + +class Qwen2Model { +public: + 
+    Qwen2Model(const LlaisysQwen2Meta &meta, llaisysDeviceType_t device, int *device_ids, int ndevice);
+    ~Qwen2Model();
+
+    LlaisysQwen2Weights *weights();
+    int64_t infer(const int64_t *token_ids, size_t ntoken, float temperature, int top_k, float top_p);
+
+private:
+    llaisysTensor_t create_weight_tensor(const std::vector<size_t> &shape);
+    llaisys::tensor_t unwrap(llaisysTensor_t tensor) const;
+    void zero_tensor(llaisysTensor_t tensor);
+    void zero_tensor_data(const llaisys::tensor_t &tensor);
+    void write_kv_cache(size_t layer, size_t pos, const llaisys::tensor_t &k, const llaisys::tensor_t &v);
+    void reset_state();
+
+    llaisys::tensor_t forward_token(int64_t token_id, size_t pos);
+
+private:
+    LlaisysQwen2Meta _meta;
+    llaisysDeviceType_t _device;
+    int _device_id;
+    LlaisysQwen2Weights _weights;
+    std::vector<llaisysTensor_t> _owned_tensors;
+    std::vector<llaisys::tensor_t> _k_cache;
+    std::vector<llaisys::tensor_t> _v_cache;
+    size_t _past_len;
+    llaisys::tensor_t _zero_bias_hs;
+    llaisys::tensor_t _zero_bias_di;
+    llaisys::tensor_t _zero_bias_voc;
+};
+
+} // namespace llaisys::models
diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp
new file mode 100644
index 000000000..b0a14ee84
--- /dev/null
+++ b/src/ops/argmax/cpu/argmax_cpu.cpp
@@ -0,0 +1,55 @@
+#include "argmax_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <type_traits>
+
+template <typename T>
+void argmax_(int64_t *max_idx, T *max_val, const T *vals, size_t numel) {
+    if (numel == 0) {
+        return;
+    }
+    size_t idx = 0;
+    T max = vals[0];
+    for (size_t i = 1; i < numel; i++) {
+        if constexpr (std::is_same_v<T, llaisys::bf16_t> ||
+                      std::is_same_v<T, llaisys::fp16_t>) {
+            float vf = llaisys::utils::cast<float>(vals[i]);
+            float bf = llaisys::utils::cast<float>(max);
+            if (vf > bf) {
+                idx = i;
+                max = vals[i];
+            }
+        } else {
+            if (vals[i] > max) {
+                idx = i;
+                max = vals[i];
+            }
+        }
+    }
+    *max_idx = static_cast<int64_t>(idx);
+    *max_val = max;
+}
+
+namespace llaisys::ops::cpu {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals,
+            llaisysDataType_t vals_dtype, size_t numel) {
+    int64_t *max_idx_ptr = reinterpret_cast<int64_t *>(max_idx);
+    switch (vals_dtype) {
+    case LLAISYS_DTYPE_F32:
+        argmax_(max_idx_ptr, reinterpret_cast<float *>(max_val),
+                reinterpret_cast<const float *>(vals), numel);
+        return;
+    case LLAISYS_DTYPE_BF16:
+        argmax_(max_idx_ptr, reinterpret_cast<llaisys::bf16_t *>(max_val),
+                reinterpret_cast<const llaisys::bf16_t *>(vals), numel);
+        return;
+    case LLAISYS_DTYPE_F16:
+        argmax_(max_idx_ptr, reinterpret_cast<llaisys::fp16_t *>(max_val),
+                reinterpret_cast<const llaisys::fp16_t *>(vals), numel);
+        return;
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(vals_dtype);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp
new file mode 100644
index 000000000..823f907f1
--- /dev/null
+++ b/src/ops/argmax/cpu/argmax_cpu.hpp
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals,
+            llaisysDataType_t vals_dtype, size_t numel);
+}
diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp
index 6dc37d426..dbf95762b 100644
--- a/src/ops/argmax/op.cpp
+++ b/src/ops/argmax/op.cpp
@@ -1,7 +1,35 @@
 #include "op.hpp"
 
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/argmax_cpu.hpp"
+
 namespace llaisys::ops {
 void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) {
-    TO_BE_IMPLEMENTED();
+    CHECK_SAME_DEVICE(max_idx, max_val, vals);
+    ASSERT(vals->ndim() == 1, "argmax: vals must be 1D.");
+    ASSERT(max_idx->numel() == 1 && max_val->numel() == 1, "argmax: max_idx and max_val must each have one element.");
+    CHECK_SAME_DTYPE(max_val->dtype(), vals->dtype());
+    ASSERT(max_idx->dtype() == LLAISYS_DTYPE_I64, "argmax: max_idx must be int64.");
+    ASSERT(max_idx->isContiguous() && max_val->isContiguous() && vals->isContiguous(), "argmax: all tensors must be contiguous.");
+
+    if (vals->deviceType() == LLAISYS_DEVICE_CPU) {
+        return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+    }
+
+    llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId());
+
+    switch (vals->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        TO_BE_IMPLEMENTED();
+        return;
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp
new file mode 100644
index 000000000..e88174a09
--- /dev/null
+++ b/src/ops/embedding/cpu/embedding_cpu.cpp
@@ -0,0 +1,35 @@
+#include "embedding_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cstring>
+
+template <typename T>
+void embedding_(T *out, const int64_t *index, const T *weight, size_t num_indices, size_t weight_rows, size_t weight_cols) {
+    for (size_t i = 0; i < num_indices; i++) {
+        int64_t row = index[i];
+        const T *src = weight + static_cast<size_t>(row) * weight_cols;
+        T *dst = out + i * weight_cols;
+        std::memcpy(dst, src, weight_cols * sizeof(T));
+    }
+}
+
+namespace llaisys::ops::cpu {
+void embedding(std::byte *out, std::byte *index, std::byte *weight, llaisysDataType_t dtype, size_t num_indices,
+               size_t weight_rows, size_t weight_cols) {
+    auto *index_ptr = reinterpret_cast<const int64_t *>(index);
+    switch (dtype) {
+    case LLAISYS_DTYPE_F32:
+        return embedding_(reinterpret_cast<float *>(out), index_ptr, reinterpret_cast<const float *>(weight),
+                          num_indices, weight_rows, weight_cols);
+    case LLAISYS_DTYPE_F16:
+        return embedding_(reinterpret_cast<llaisys::fp16_t *>(out), index_ptr,
+                          reinterpret_cast<const llaisys::fp16_t *>(weight), num_indices, weight_rows, weight_cols);
+    case LLAISYS_DTYPE_BF16:
+        return embedding_(reinterpret_cast<llaisys::bf16_t *>(out), index_ptr,
+                          reinterpret_cast<const llaisys::bf16_t *>(weight), num_indices, weight_rows, weight_cols);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp
new file mode 100644
index 000000000..13fb87bb6
--- /dev/null
+++ b/src/ops/embedding/cpu/embedding_cpu.hpp
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void embedding(std::byte *out, std::byte *index, std::byte *weight, llaisysDataType_t dtype, size_t num_indices, size_t weight_rows, size_t weight_cols);
+}
\ No newline at end of file
diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp
index 84b9a5d06..2ef85feaa 100644
--- a/src/ops/embedding/op.cpp
+++ b/src/ops/embedding/op.cpp
@@ -1,7 +1,37 @@
 #include "op.hpp"
+#include "../../utils.hpp"
+#include "cpu/embedding_cpu.hpp"
 
 namespace llaisys::ops {
 void embedding(tensor_t out, tensor_t index, tensor_t weight) {
-    TO_BE_IMPLEMENTED();
+    CHECK_SAME_DEVICE(out, index, weight);
+    CHECK_SAME_DTYPE(out->dtype(), weight->dtype());
+    ASSERT(index->dtype() == LLAISYS_DTYPE_I64, "Embedding: index tensor must be of type INT64.");
+    // out shape: [index.shape[0], weight.shape[1]]
+    ASSERT(out->shape().size() == 2, "Embedding: output tensor must be 2-dimensional.");
+    ASSERT(index->shape().size() == 1, "Embedding: index tensor must be 1-dimensional.");
+    ASSERT(weight->shape().size() == 2, "Embedding: weight tensor must be 2-dimensional.");
+    ASSERT(out->shape()[0] == index->shape()[0], "Embedding: output tensor's first dimension must match index tensor's size.");
+    ASSERT(out->shape()[1] == weight->shape()[1], "Embedding: output tensor's second dimension must match weight tensor's second dimension.");
+    ASSERT(weight->isContiguous() && out->isContiguous() && index->isContiguous(), "Embedding: out, index and weight tensors must be contiguous.");
+
+    // always support cpu calculation
+    if (out->deviceType() == LLAISYS_DEVICE_CPU) {
+        return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index->shape()[0], weight->shape()[0], weight->shape()[1]);
+    }
+
+    llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+
+    switch (out->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index->shape()[0], weight->shape()[0], weight->shape()[1]);
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        TO_BE_IMPLEMENTED();
+        return;
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp
new file mode 100644
index 000000000..1f0f51586
--- /dev/null
+++ b/src/ops/linear/cpu/linear_cpu.cpp
@@ -0,0 +1,41 @@
+#include "linear_cpu.hpp"
+#include "../../../utils.hpp"
+#include <type_traits>
+
+template <typename T>
+void linear_(T *out, const T *in, const T *weight, const T *bias, size_t input_rows, size_t input_cols, size_t weight_rows) {
+    for (size_t i = 0; i < input_rows; i++) {
+        for (size_t j = 0; j < weight_rows; j++) {
+            if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+                // Accumulate in float to avoid unsupported operators on bf16/fp16.
+                float acc = llaisys::utils::cast<float>(bias[j]);
+                // Accumulate matrix multiplication: out = in @ weight.T
+                for (size_t k = 0; k < input_cols; k++) {
+                    acc += llaisys::utils::cast<float>(in[i * input_cols + k]) *
+                           llaisys::utils::cast<float>(weight[j * input_cols + k]);
+                }
+                out[i * weight_rows + j] = llaisys::utils::cast<T>(acc);
+            } else {
+                out[i * weight_rows + j] = bias[j];
+                for (size_t k = 0; k < input_cols; k++) {
+                    out[i * weight_rows + j] += in[i * input_cols + k] * weight[j * input_cols + k];
+                }
+            }
+        }
+    }
+}
+
+namespace llaisys::ops::cpu {
+void linear(std::byte *out, std::byte *in, std::byte *weight, std::byte *bias, llaisysDataType_t type, size_t input_rows, size_t input_cols, size_t weight_rows) {
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        return linear_(reinterpret_cast<float *>(out), reinterpret_cast<const float *>(in), reinterpret_cast<const float *>(weight), reinterpret_cast<const float *>(bias), input_rows, input_cols, weight_rows);
+    case LLAISYS_DTYPE_BF16:
+        return linear_(reinterpret_cast<llaisys::bf16_t *>(out), reinterpret_cast<const llaisys::bf16_t *>(in), reinterpret_cast<const llaisys::bf16_t *>(weight), reinterpret_cast<const llaisys::bf16_t *>(bias), input_rows, input_cols, weight_rows);
+    case LLAISYS_DTYPE_F16:
+        return linear_(reinterpret_cast<llaisys::fp16_t *>(out), reinterpret_cast<const llaisys::fp16_t *>(in), reinterpret_cast<const llaisys::fp16_t *>(weight), reinterpret_cast<const llaisys::fp16_t *>(bias), input_rows, input_cols, weight_rows);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(type);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp
new file mode 100644
index 000000000..5edab89a2
--- /dev/null
+++ b/src/ops/linear/cpu/linear_cpu.hpp
@@ -0,0 +1,7 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void linear(std::byte *out, std::byte *in, std::byte *weight, std::byte *bias, llaisysDataType_t type, size_t input_rows, size_t input_cols, size_t weight_rows);
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp
index 97d1f8655..3e2d25094 100644
--- a/src/ops/linear/op.cpp
+++ b/src/ops/linear/op.cpp
@@ -1,7 +1,36 @@
 #include "op.hpp"
+#include "../../utils.hpp"
+#include "cpu/linear_cpu.hpp"
 
 namespace llaisys::ops {
 void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) {
-    TO_BE_IMPLEMENTED();
+    CHECK_SAME_DEVICE(out, in, weight, bias);
+    // Only support contiguous tensor with same data type
+    CHECK_SAME_DTYPE(out->dtype(), in->dtype(), weight->dtype(), bias->dtype());
+    ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous() && bias->isContiguous(), "Linear: all tensors must be contiguous.");
+    // only support 2D input and weight, 1D bias
+    ASSERT(in->ndim() == 2 && weight->ndim() == 2 && bias->ndim() == 1, "Linear: in and weight must be 2D tensors, bias must be 1D tensor.");
+    // out = in @ weight.T + bias
+    ASSERT(in->shape()[1] == weight->shape()[1], "Linear: in.shape[1] must be equal to weight.shape[1].");
+    ASSERT(out->shape()[0] == in->shape()[0] && out->shape()[1] == weight->shape()[0], "Linear: out.shape must be equal to (in.shape[0], weight.shape[0]).");
+    ASSERT(bias->shape()[0] == weight->shape()[0], "Linear: bias.shape[0] must be equal to weight.shape[0].");
+
+    // always support cpu calculation
+    if (out->deviceType() == LLAISYS_DEVICE_CPU) {
+        return cpu::linear(out->data(), in->data(), weight->data(), bias->data(), out->dtype(), in->shape()[0], in->shape()[1], weight->shape()[0]);
+    }
+
+    llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+    switch (out->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::linear(out->data(), in->data(), weight->data(), bias->data(), out->dtype(), in->shape()[0], in->shape()[1], weight->shape()[0]);
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        TO_BE_IMPLEMENTED();
+        return;
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.cpp b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp
new file mode 100644
index 000000000..04526f9ed
--- /dev/null
+++ b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp
@@ -0,0 +1,50 @@
+#include "rms_norm_cpu.hpp"
+#include "../../../utils.hpp"
+#include <cmath>
+#include <type_traits>
+
+template <typename T>
+void rms_norm_(T *out, const T *in, const T *weight, size_t input_rows, size_t input_cols, float eps) {
+    for (size_t i = 0; i < input_rows; i++) {
+        float sum_squares = 0.0f;
+        if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+            for (size_t j = 0; j < input_cols; j++) {
+                float val = llaisys::utils::cast<float>(in[i * input_cols + j]);
+                sum_squares += val * val;
+            }
+
+            float rms = std::sqrt(sum_squares / input_cols + eps);
+            for (size_t j = 0; j < input_cols; j++) {
+                float val = llaisys::utils::cast<float>(in[i * input_cols + j]);
+                out[i * input_cols + j] = llaisys::utils::cast<T>(val * llaisys::utils::cast<float>(weight[j]) / rms);
+            }
+        } else {
+            for (size_t j = 0; j < input_cols; j++) {
+                float val = static_cast<float>(in[i * input_cols + j]);
+                sum_squares += val * val;
+            }
+
+            float rms = std::sqrt(sum_squares / input_cols + eps);
+            for (size_t j = 0; j < input_cols; j++) {
+                float val = static_cast<float>(in[i * input_cols + j]);
+                out[i * input_cols + j] = static_cast<T>(val * static_cast<float>(weight[j]) / rms);
+            }
+        }
+    }
+}
+
+namespace llaisys::ops::cpu {
+void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, size_t input_rows, size_t input_cols, float eps) {
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        return rms_norm_(reinterpret_cast<float *>(out), reinterpret_cast<const float *>(in), reinterpret_cast<const float *>(weight), input_rows, input_cols, eps);
+    case LLAISYS_DTYPE_BF16:
+        return rms_norm_(reinterpret_cast<llaisys::bf16_t *>(out), reinterpret_cast<const llaisys::bf16_t *>(in),
+                         reinterpret_cast<const llaisys::bf16_t *>(weight), input_rows, input_cols, eps);
+    case LLAISYS_DTYPE_F16:
+        return rms_norm_(reinterpret_cast<llaisys::fp16_t *>(out), reinterpret_cast<const llaisys::fp16_t *>(in),
+                         reinterpret_cast<const llaisys::fp16_t *>(weight), input_rows, input_cols, eps);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(type);
+    }
+}
+} // namespace llaisys::ops::cpu
\ No newline at end of file
diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.hpp b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp
new file mode 100644
index 000000000..075466298
--- /dev/null
+++ b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp
@@ -0,0 +1,9 @@
+
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight, llaisysDataType_t type, size_t input_rows, size_t input_cols, float eps);
+}
\ No newline at end of file
diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp
index 529553d9d..fb8cae6e7 100644
--- a/src/ops/rms_norm/op.cpp
+++ b/src/ops/rms_norm/op.cpp
@@ -1,7 +1,36 @@
 #include "op.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/rms_norm_cpu.hpp"
 
 namespace llaisys::ops {
 void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) {
-    TO_BE_IMPLEMENTED();
+    CHECK_SAME_DEVICE(out, in, weight);
+    // Only support contiguous inputs with same shape for now.
+    CHECK_SAME_SHAPE(out->shape(), in->shape());
+    CHECK_SAME_DTYPE(out->dtype(), in->dtype());
+    ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(), "RMSNorm: all tensors must be contiguous.");
+    // Only support 2D input and output for now.
+    ASSERT(in->ndim() == 2, "RMSNorm: only support 2D input tensor.");
+    // Check weight shape
+    ASSERT(weight->ndim() == 1 && weight->shape()[0] == in->shape()[1], "RMSNorm: weight shape is invalid.");
+
+    if (out->deviceType() == LLAISYS_DEVICE_CPU) {
+        return cpu::rms_norm(out->data(), in->data(), weight->data(), out->dtype(), in->shape()[0], in->shape()[1], eps);
+    }
+
+    llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+
+    switch (out->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::rms_norm(out->data(), in->data(), weight->data(), out->dtype(), in->shape()[0], in->shape()[1], eps);
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        TO_BE_IMPLEMENTED();
+        return;
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/rope/cpu/rope_cpu.cpp b/src/ops/rope/cpu/rope_cpu.cpp
new file mode 100644
index 000000000..a4c59bbb0
--- /dev/null
+++ b/src/ops/rope/cpu/rope_cpu.cpp
@@ -0,0 +1,47 @@
+#include "rope_cpu.hpp"
+#include "../../../utils.hpp"
+#include <cmath>
+#include <type_traits>
+
+template <typename T>
+void rope_(T *out, const T *in, const int64_t *pos_ids, size_t seq_length, size_t head_nums, size_t head_dim, float theta) {
+    const size_t max_i = head_dim / 2;
+    for (size_t s = 0; s < seq_length; s++) {
+        for (size_t h = 0; h < head_nums; h++) {
+            for (size_t i = 0; i < max_i; i++) {
+                float angle = pos_ids[s] / std::pow(theta, (2.0f * i) / head_dim);
+                float cos_angle = std::cos(angle);
+                float sin_angle = std::sin(angle);
+
+                size_t base = s * head_nums * head_dim + h * head_dim;
+                size_t index_a = base + i;
+                size_t index_b = base + i + max_i;
+                if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+                    float x1 = llaisys::utils::cast<float>(in[index_a]);
+                    float x2 = llaisys::utils::cast<float>(in[index_b]);
+                    out[index_a] = llaisys::utils::cast<T>(x1 * cos_angle - x2 * sin_angle);
+                    out[index_b] = llaisys::utils::cast<T>(x2 * cos_angle + x1 * sin_angle);
+                } else {
+                    float x1 = in[index_a];
+                    float x2 = in[index_b];
+                    out[index_a] = x1 * cos_angle - x2 * sin_angle;
+                    out[index_b] = x2 * cos_angle + x1 * sin_angle;
+                }
+            }
+        }
+    }
+}
+
+namespace llaisys::ops::cpu {
+void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, size_t seq_length, size_t head_nums, size_t head_dim, float theta) {
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        return rope_(reinterpret_cast<float *>(out), reinterpret_cast<const float *>(in), reinterpret_cast<const int64_t *>(pos_ids), seq_length, head_nums, head_dim, theta);
+    case LLAISYS_DTYPE_BF16:
+        return rope_(reinterpret_cast<llaisys::bf16_t *>(out), reinterpret_cast<const llaisys::bf16_t *>(in), reinterpret_cast<const int64_t *>(pos_ids), seq_length, head_nums, head_dim, theta);
+    case LLAISYS_DTYPE_F16:
+        return rope_(reinterpret_cast<llaisys::fp16_t *>(out), reinterpret_cast<const llaisys::fp16_t *>(in), reinterpret_cast<const int64_t *>(pos_ids), seq_length, head_nums, head_dim, theta);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(type);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/rope/cpu/rope_cpu.hpp b/src/ops/rope/cpu/rope_cpu.hpp
new file mode 100644
index 000000000..5c83f48f2
--- /dev/null
+++ b/src/ops/rope/cpu/rope_cpu.hpp
@@ -0,0 +1,7 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids, llaisysDataType_t type, size_t seq_length, size_t head_nums, size_t head_dim, float theta);
+}
\ No newline at end of file
diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp
index d60dbe64e..903452c5b 100644
--- a/src/ops/rope/op.cpp
+++ b/src/ops/rope/op.cpp
@@ -1,7 +1,35 @@
 #include "op.hpp"
+#include "cpu/rope_cpu.hpp"
+
 namespace llaisys::ops {
 void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) {
-    TO_BE_IMPLEMENTED();
+    CHECK_SAME_DEVICE(out, in, pos_ids);
+    CHECK_SAME_DTYPE(out->dtype(), in->dtype());
+    CHECK_SAME_DTYPE(pos_ids->dtype(), LLAISYS_DTYPE_I64);
+    ASSERT(in->ndim() == 3, "Rope: input tensor must be 3-dimensional.");
+    ASSERT(out->ndim() == 3, "Rope: output tensor must be 3-dimensional.");
+    ASSERT(pos_ids->ndim() == 1, "Rope: position ids tensor must be 1-dimensional.");
+    CHECK_SAME_SHAPE(out->shape(), in->shape());
+    ASSERT(in->shape()[0] == pos_ids->shape()[0], "Rope: position ids length must match input tensor sequence length.");
+    ASSERT(in->shape()[2] % 2 == 0, "Rope: head dimension must be even.");
+    ASSERT(out->isContiguous() && in->isContiguous() && pos_ids->isContiguous(), "Rope: all tensors must be contiguous.");
+
+    // always support cpu calculation
+    if (out->deviceType() == LLAISYS_DEVICE_CPU) {
+        return cpu::rope(out->data(), in->data(), pos_ids->data(), out->dtype(), in->shape()[0], in->shape()[1], in->shape()[2], theta);
+    }
+    llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+    switch (out->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::rope(out->data(), in->data(), pos_ids->data(), out->dtype(), in->shape()[0], in->shape()[1], in->shape()[2], theta);
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        TO_BE_IMPLEMENTED();
+        return;
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/sample/cpu/sample_cpu.cpp b/src/ops/sample/cpu/sample_cpu.cpp
new file mode 100644
index 000000000..92a40c908
--- /dev/null
+++ b/src/ops/sample/cpu/sample_cpu.cpp
@@ -0,0 +1,189 @@
+/**
+ * sample_cpu.cpp — CPU implementation of the random sampling operator
+ *
+ * This operator takes the raw logits output by the language model and samples
+ * the next token using three optional filtering strategies:
+ *
+ *   Temperature scaling — controls the "sharpness" of the distribution
+ *   Top-K filtering     — restricts sampling to the K most likely tokens
+ *   Top-P (nucleus)     — restricts sampling to the smallest set of tokens
+ *                         whose cumulative probability exceeds P
+ *
+ * Algorithm overview (applied in order):
+ *   1. Convert logits to float32 and divide by temperature
+ *   2. Sort token indices by score (descending)
+ *   3. Keep only the top-K indices (if top_k > 0)
+ *   4. Compute softmax over the kept scores → probabilities
+ *   5. Truncate to the nucleus (if top_p < 1.0) and re-normalize
+ *   6. Draw one sample from the resulting discrete distribution
+ *
+ * References:
+ *   Temperature: https://arxiv.org/abs/1904.09751
+ *   Top-K: https://arxiv.org/abs/1805.04833
+ *   Top-P: https://arxiv.org/abs/1904.09751
+ */
+
+#include "sample_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstdint>
+#include <numeric>
+#include <random>
+#include <type_traits>
+#include <vector>
+
+/**
+ * Core sampling function, templated over the logit data type T.
+ * Supports float32, float16, and bfloat16 via the dtype dispatch in sample().
+ *
+ * @param sampled_token Output: pointer to a single int64 that receives the
+ *                      sampled token index.
+ * @param logits        Input: raw unnormalized scores from the model's final
+ *                      linear layer, one per vocabulary token.
+ * @param vocab_size    Number of tokens in the vocabulary.
+ * @param temperature   Divides all logits before softmax.
+ *                        < 1.0 → sharper distribution (more confident)
+ *                        > 1.0 → flatter distribution (more random)
+ *                        = 1.0 → no change
+ * @param top_k         Keep only the top-K tokens. 0 = disabled (keep all).
+ * @param top_p         Nucleus threshold in [0, 1]. 1.0 = disabled (keep all).
+ */
+template <typename T>
+void sample_(int64_t *sampled_token, const T *logits, size_t vocab_size,
+             float temperature, int top_k, float top_p) {
+    if (vocab_size == 0) {
+        return;
+    }
+
+    // ── Step 1: Convert to float32 and apply temperature scaling ─────────────
+    // Language models output logits in their native dtype (BF16, F16, or F32).
+    // We convert everything to float32 for numerical stability.
+    // Dividing by temperature before softmax is equivalent to raising the
+    // probabilities to the power of (1/temperature) after softmax.
+    std::vector<float> scores(vocab_size);
+    for (size_t i = 0; i < vocab_size; i++) {
+        float val;
+        if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+            val = llaisys::utils::cast<float>(logits[i]);
+        } else {
+            val = static_cast<float>(logits[i]);
+        }
+        scores[i] = val / temperature;
+    }
+
+    // ── Step 2: Sort token indices by score (descending) ─────────────────────
+    // We sort an index array rather than the scores themselves so we can
+    // map back to the original token IDs after filtering.
+    std::vector<size_t> indices(vocab_size);
+    std::iota(indices.begin(), indices.end(), 0); // fill with 0, 1, 2, ...
+    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
+        return scores[a] > scores[b]; // descending order
+    });
+
+    // ── Step 3: Top-K filtering ──────────────────────────────────────────────
+    // Discard all but the K highest-scoring tokens.
+    // Example: vocab_size=50000, top_k=50 → only 50 candidates remain.
+    // top_k=0 means "keep all" (no filtering).
+    // top_k=1 is equivalent to argmax (always picks the most likely token).
+    size_t keep = vocab_size;
+    if (top_k > 0 && static_cast<size_t>(top_k) < vocab_size) {
+        keep = static_cast<size_t>(top_k);
+    }
+    indices.resize(keep);
+
+    // ── Step 4: Numerically stable softmax over kept candidates ──────────────
+    // Softmax converts raw scores to probabilities that sum to 1.
+    // Formula: p_i = exp(s_i) / sum(exp(s_j))
+    //
+    // Numerical stability trick: subtract the maximum score before exp().
+    // This prevents overflow (exp of large numbers → inf) without changing
+    // the result, because the max cancels out in numerator and denominator:
+    //   exp(s_i - max) / sum(exp(s_j - max)) == exp(s_i) / sum(exp(s_j))
+    float max_score = scores[indices[0]]; // indices[0] is the highest score
+    std::vector<float> probs(keep);
+    float sum = 0.0f;
+    for (size_t i = 0; i < keep; i++) {
+        probs[i] = std::exp(scores[indices[i]] - max_score);
+        sum += probs[i];
+    }
+    for (size_t i = 0; i < keep; i++) {
+        probs[i] /= sum; // normalize to sum to 1
+    }
+
+    // ── Step 5: Top-P (nucleus) filtering ────────────────────────────────────
+    // Walk through tokens in probability order (already sorted from step 2).
+    // Keep adding tokens until the cumulative probability reaches top_p,
+    // then discard the rest. Re-normalize the kept probabilities.
+    //
+    // Example: top_p=0.9 with probs [0.5, 0.3, 0.15, 0.05]
+    //   After token 0: cumsum = 0.5  (< 0.9, keep going)
+    //   After token 1: cumsum = 0.8  (< 0.9, keep going)
+    //   After token 2: cumsum = 0.95 (>= 0.9, stop here, cutoff = 3)
+    //   Token 3 is discarded. Remaining probs re-normalized to sum to 1.
+    if (top_p < 1.0f) {
+        float cumsum = 0.0f;
+        size_t cutoff = keep;
+        for (size_t i = 0; i < keep; i++) {
+            cumsum += probs[i];
+            if (cumsum >= top_p) {
+                cutoff = i + 1;
+                break;
+            }
+        }
+        if (cutoff < keep) {
+            keep = cutoff;
+            indices.resize(keep);
+            probs.resize(keep);
+            // Re-normalize so probabilities sum to 1 again
+            float new_sum = 0.0f;
+            for (float p : probs) new_sum += p;
+            for (float &p : probs) p /= new_sum;
+        }
+    }
+
+    // ── Step 6: Multinomial sampling ─────────────────────────────────────────
+    // Draw one token index from the filtered probability distribution.
+    // std::discrete_distribution handles the weighted random selection.
+    //
+    // thread_local: each thread gets its own RNG state, avoiding data races
+    // in multi-threaded scenarios.
+    //
+    // Seed: XOR of std::random_device (hardware entropy) and the current
+    // time. On Windows, std::random_device alone may return a fixed value,
+    // so we mix in the clock to ensure different seeds across calls.
+    thread_local std::mt19937 rng(
+        std::random_device{}() ^
+        static_cast<std::mt19937::result_type>(std::chrono::steady_clock::now().time_since_epoch().count())
+    );
+    std::discrete_distribution<int> dist(probs.begin(), probs.end());
+    *sampled_token = static_cast<int64_t>(indices[dist(rng)]);
+}
+
+/**
+ * Public entry point for the sample operator.
+ * Dispatches to the templated sample_() based on the logit data type.
+ * Raw byte pointers are used to match the generic tensor data interface.
+ */
+namespace llaisys::ops::cpu {
+void sample(std::byte *sampled_token, const std::byte *logits, llaisysDataType_t logits_dtype,
+            size_t vocab_size, float temperature, int top_k, float top_p) {
+    int64_t *token_ptr = reinterpret_cast<int64_t *>(sampled_token);
+
+    switch (logits_dtype) {
+    case LLAISYS_DTYPE_F32:
+        sample_(token_ptr, reinterpret_cast<const float *>(logits), vocab_size, temperature, top_k, top_p);
+        return;
+    case LLAISYS_DTYPE_BF16:
+        sample_(token_ptr, reinterpret_cast<const llaisys::bf16_t *>(logits), vocab_size, temperature, top_k, top_p);
+        return;
+    case LLAISYS_DTYPE_F16:
+        sample_(token_ptr, reinterpret_cast<const llaisys::fp16_t *>(logits), vocab_size, temperature, top_k, top_p);
+        return;
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(logits_dtype);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/sample/cpu/sample_cpu.hpp b/src/ops/sample/cpu/sample_cpu.hpp
new file mode 100644
index 000000000..bc0108ca0
--- /dev/null
+++ b/src/ops/sample/cpu/sample_cpu.hpp
@@ -0,0 +1,19 @@
+#pragma once
+
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+
+void sample(
+    std::byte *sampled_token,       // Output: int64 token
+    const std::byte *logits,        // Input: logits array
+    llaisysDataType_t logits_dtype,
+    size_t vocab_size,
+    float temperature,
+    int top_k,
+    float top_p
+);
+
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/sample/op.cpp b/src/ops/sample/op.cpp
new file mode 100644
index 000000000..f70a51f1d
--- /dev/null
+++ b/src/ops/sample/op.cpp
@@ -0,0 +1,37 @@
+#include "op.hpp"
+
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/sample_cpu.hpp"
+
+namespace llaisys::ops {
+void sample(tensor_t sampled_token, tensor_t logits, float temperature, int top_k, float top_p) {
+    CHECK_SAME_DEVICE(sampled_token, logits);
+    ASSERT(logits->ndim() == 1, "sample: logits must be 1D.");
+    ASSERT(sampled_token->numel() == 1, "sample: sampled_token must have one element.");
+    ASSERT(sampled_token->dtype() == LLAISYS_DTYPE_I64, "sample: sampled_token must be int64.");
+    ASSERT(sampled_token->isContiguous() && logits->isContiguous(), "sample: all tensors must be contiguous.");
+    ASSERT(temperature > 0.0f, "sample: temperature must be positive.");
+    ASSERT(top_k >= 0, "sample: top_k must be non-negative.");
+    ASSERT(top_p >= 0.0f && top_p <= 1.0f, "sample: top_p must be in [0, 1].");
+
+    if (logits->deviceType() == LLAISYS_DEVICE_CPU) {
+        return cpu::sample(sampled_token->data(), logits->data(), logits->dtype(), logits->numel(), temperature, top_k, top_p);
+    }
+
+    llaisys::core::context().setDevice(logits->deviceType(), logits->deviceId());
+
+    switch (logits->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::sample(sampled_token->data(), logits->data(), logits->dtype(), logits->numel(), temperature, top_k, top_p);
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        TO_BE_IMPLEMENTED();
+        return;
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
+}
+} // namespace llaisys::ops
diff --git a/src/ops/sample/op.hpp b/src/ops/sample/op.hpp
new file mode 100644
index 000000000..705005d1d
--- /dev/null
+++ b/src/ops/sample/op.hpp
@@ -0,0 +1,15 @@
+#pragma once
+
+#include "../../tensor/tensor.hpp"
+
+namespace llaisys::ops {
+
+void sample(
+    tensor_t sampled_token,
+    tensor_t logits,
+    float temperature,
+    int top_k,
+    float top_p
+);
+
+} // namespace llaisys::ops
diff --git a/src/ops/self_attention/cpu/self_attension_cpu.cpp b/src/ops/self_attention/cpu/self_attension_cpu.cpp
new file mode 100644
index 000000000..ab434a915
--- /dev/null
+++ b/src/ops/self_attention/cpu/self_attension_cpu.cpp
@@ -0,0 +1,139 @@
+#include "self_attension_cpu.hpp"
+#include "../../../utils.hpp"
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <type_traits>
+#include <vector>
+
+template <typename T>
+void self_attention_(T *attn_val, const T *q, const T *k, const T *v, size_t seqlen, size_t total_len, size_t nhead, size_t nkv_head, size_t d, size_t dv, float scale) {
+    size_t group_size = nhead / nkv_head;
+    std::vector<float> scores_buffer(total_len);
+    std::fill(attn_val, attn_val + seqlen * nhead * dv, llaisys::utils::cast<T>(0.0f));
+
+    if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+        for (size_t i = 0; i < seqlen; i++) {
+            for (size_t h = 0; h < nhead; h++) {
+                size_t group_id = h / group_size;
+
+                float max_score = -std::numeric_limits<float>::infinity();
+                // Last cache position this query row may attend to (causal mask).
+                const size_t max_t = i + (total_len - seqlen);
+                const size_t t_end = std::min(max_t, total_len - 1);
+
+                for (size_t t = 0; t <= t_end; t++) {
+                    size_t q_base = (i * nhead + h) * d;
+                    size_t k_base = (t * nkv_head + group_id) * d;
+
+                    float score = 0.0f;
+                    for (size_t j = 0; j < d; j++) {
+                        score += llaisys::utils::cast<float>(q[q_base + j]) * llaisys::utils::cast<float>(k[k_base + j]);
+                    }
+                    score *= scale;
+                    scores_buffer[t] = score;
+
+                    if (score > max_score) {
+                        max_score = score;
+                    }
+                }
+
+                // Mask out invalid positions to match causal attention.
+ for (size_t t = t_end + 1; t < total_len; t++) { + scores_buffer[t] = -std::numeric_limits::infinity(); + } + + // compute softmax + float sum_exp = 0.0f; + for (size_t t = 0; t < total_len; t++) { + scores_buffer[t] = std::exp(scores_buffer[t] - max_score); + sum_exp += scores_buffer[t]; + } + for (size_t t = 0; t < total_len; t++) { + scores_buffer[t] /= sum_exp; + } + + size_t attn_base = (i * nhead + h) * dv; + for (size_t j = 0; j < dv; j++) { + float out = 0.0f; + for (size_t t = 0; t < total_len; t++) { + size_t v_base = (t * nkv_head + group_id) * dv; + out += scores_buffer[t] * llaisys::utils::cast(v[v_base + j]); + } + attn_val[attn_base + j] = llaisys::utils::cast(out); + } + } + } + } else { + for (size_t i = 0; i < seqlen; i++) { + for (size_t h = 0; h < nhead; h++) { + size_t group_id = h / group_size; + + float max_score = -std::numeric_limits::infinity(); + const size_t max_t = i + (total_len - seqlen); + + if (max_t < 0) { + continue; + } + + size_t q_base = (i * nhead + h) * d; + for (size_t t = 0; t < total_len; t++) { + size_t k_base = (t * nkv_head + group_id) * d; + + float score = 0; + + for (size_t j = 0; j < d; j++) { + score += q[q_base + j] * k[k_base + j]; + } + score *= scale; + scores_buffer[t] = score; + + if (score > max_score) { + max_score = score; + } + } + + // Mask out invalid positions to match causal attention. 
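Both branches of the kernel share the same math; only the cast strategy differs. A plain-Python reference (hypothetical `attention_ref`, float-only) makes the causal mask and the GQA head-to-group mapping easy to verify against small hand-built inputs:

```python
import math

def attention_ref(q, k, v, scale):
    # q: [seqlen][nhead][d], k: [total_len][nkv_head][d], v: [total_len][nkv_head][dv]
    seqlen, nhead = len(q), len(q[0])
    total_len, nkv_head = len(k), len(k[0])
    group_size = nhead // nkv_head
    out = [[[0.0] * len(v[0][0]) for _ in range(nhead)] for _ in range(seqlen)]
    for i in range(seqlen):
        # Causal mask: query i may attend to keys 0..i + (total_len - seqlen).
        t_end = i + (total_len - seqlen)
        for h in range(nhead):
            g = h // group_size  # GQA: several query heads share one KV head
            scores = [scale * sum(qj * kj for qj, kj in zip(q[i][h], k[t][g]))
                      for t in range(t_end + 1)]
            m = max(scores)
            w = [math.exp(s - m) for s in scores]
            z = sum(w)
            for t in range(t_end + 1):
                for j, vj in enumerate(v[t][g]):
                    out[i][h][j] += (w[t] / z) * vj
    return out
```

Because the first query row can only see the first key, its output must equal `v[0]` exactly — a handy sanity check for the causal mask.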
+                for (size_t t = max_t + 1; t < total_len; t++) {
+                    scores_buffer[t] = -std::numeric_limits<float>::infinity();
+                }
+
+                // compute softmax
+                float sum_exp = 0.0f;
+                for (size_t t = 0; t < total_len; t++) {
+                    scores_buffer[t] = std::exp(scores_buffer[t] - max_score);
+                    sum_exp += scores_buffer[t];
+                }
+                for (size_t t = 0; t < total_len; t++) {
+                    scores_buffer[t] /= sum_exp;
+                }
+
+                size_t attn_base = (i * nhead + h) * dv;
+                for (size_t t = 0; t < total_len; t++) {
+                    size_t v_base = (t * nkv_head + group_id) * dv;
+
+                    for (size_t j = 0; j < dv; j++) {
+                        attn_val[attn_base + j] += scores_buffer[t] * v[v_base + j];
+                    }
+                }
+            }
+        }
+    }
+}
+
+namespace llaisys::ops::cpu {
+void self_attention(std::byte *attn_val, const std::byte *q, const std::byte *k, const std::byte *v, llaisysDataType_t dtype, size_t seqlen, size_t total_len, size_t nhead, size_t kv_head, size_t d, size_t dv, float scale) {
+    switch (dtype) {
+    case llaisysDataType_t::LLAISYS_DTYPE_F32:
+        return self_attention_(reinterpret_cast<float *>(attn_val), reinterpret_cast<const float *>(q), reinterpret_cast<const float *>(k), reinterpret_cast<const float *>(v), seqlen, total_len, nhead, kv_head, d, dv, scale);
+    case llaisysDataType_t::LLAISYS_DTYPE_F16:
+        return self_attention_(reinterpret_cast<llaisys::fp16_t *>(attn_val), reinterpret_cast<const llaisys::fp16_t *>(q), reinterpret_cast<const llaisys::fp16_t *>(k), reinterpret_cast<const llaisys::fp16_t *>(v), seqlen, total_len, nhead, kv_head, d, dv, scale);
+    case llaisysDataType_t::LLAISYS_DTYPE_BF16:
+        return self_attention_(reinterpret_cast<llaisys::bf16_t *>(attn_val), reinterpret_cast<const llaisys::bf16_t *>(q), reinterpret_cast<const llaisys::bf16_t *>(k), reinterpret_cast<const llaisys::bf16_t *>(v), seqlen, total_len, nhead, kv_head, d, dv, scale);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/self_attention/cpu/self_attension_cpu.hpp b/src/ops/self_attention/cpu/self_attension_cpu.hpp new file mode 100644 index 000000000..add9538df --- /dev/null +++ b/src/ops/self_attention/cpu/self_attension_cpu.hpp @@ -0,0 +1,7 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void self_attention(std::byte
*attn_val, const std::byte *q, const std::byte *k, const std::byte *v, llaisysDataType_t dtype, size_t seqlen, size_t total_len, size_t nhead, size_t kv_head, size_t d, size_t dv, float scale); +} \ No newline at end of file diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp index 43d620142..68985ca0b 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,49 @@ #include "op.hpp" +#include "../../core/llaisys_core.hpp" +#include "../../utils.hpp" +#include "cpu/self_attension_cpu.hpp" + namespace llaisys::ops { void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); + CHECK_SAME_DEVICE(attn_val, q, k, v); + CHECK_SAME_DTYPE(attn_val->dtype(), q->dtype(), k->dtype(), v->dtype()); + // Only support contiguous inputs for now. + ASSERT(attn_val->isContiguous() && q->isContiguous() && k->isContiguous() && v->isContiguous(), "Self-Attention: all tensors must be contiguous."); + // attn_val, q, k, v should have same ndim and should be 3D tensors + CHECK_SAME(attn_val->ndim(), q->ndim(), k->ndim(), v->ndim()); + ASSERT(attn_val->ndim() == 3, "Self-Attention: only support 3D tensors for now."); + // attn_val shape dim 0,1 is same as q shape + CHECK_SAME_SHAPE(attn_val->shape()[0], q->shape()[0]); + CHECK_SAME_SHAPE(attn_val->shape()[1], q->shape()[1]); + + // k, v shape dim 0 is kvlen; allow kvlen != qlen + CHECK_SAME_SHAPE(k->shape()[0], v->shape()[0]); + + // q and k must share head dim; output matches v head dim + CHECK_SAME_SHAPE(k->shape()[2], q->shape()[2]); + CHECK_SAME_SHAPE(attn_val->shape()[2], v->shape()[2]); + + // nhead and kv_head should divide evenly + ASSERT(q->shape()[1] % k->shape()[1] == 0, "Self-Attention: nhead must be divisible by kv_head."); + + if (attn_val->deviceType() == LLAISYS_DEVICE_CPU) { + cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), attn_val->dtype(), q->shape()[0], k->shape()[0], q->shape()[1], 
k->shape()[1], q->shape()[2], v->shape()[2], scale);
+        return;
+    }
+
+    llaisys::core::context().setDevice(attn_val->deviceType(), attn_val->deviceId());
+
+    switch (attn_val->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::self_attention(attn_val->data(), q->data(), k->data(), v->data(), attn_val->dtype(), q->shape()[0], k->shape()[0], q->shape()[1], k->shape()[1], q->shape()[2], v->shape()[2], scale);
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        TO_BE_IMPLEMENTED();
+        return;
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/ops/swiglu/cpu/swiglu_cpu.cpp b/src/ops/swiglu/cpu/swiglu_cpu.cpp new file mode 100644 index 000000000..8f1c9c406 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.cpp @@ -0,0 +1,35 @@
+#include "swiglu_cpu.hpp"
+#include "../../../utils.hpp"
+#include <cmath>
+#include <type_traits>
+
+template <typename T>
+void swiglu_(T *out, const T *gate, const T *up, size_t seqlen, size_t intermediate_size) {
+    const size_t max_idx = seqlen * intermediate_size;
+    if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
+        for (size_t i = 0; i < max_idx; i++) {
+            float gate_f = llaisys::utils::cast<float>(gate[i]);
+            float up_f = llaisys::utils::cast<float>(up[i]);
+            float sigmoid = 1.0f / (1.0f + std::exp(-gate_f));
+            out[i] = llaisys::utils::cast<T>(up_f * sigmoid * gate_f);
+        }
+    } else {
+        for (size_t i = 0; i < max_idx; i++) {
+            out[i] = up[i] * (1.0f / (1.0f + std::exp(-gate[i]))) * gate[i];
+        }
+    }
+}
+
+namespace llaisys::ops::cpu {
+void swiglu(std::byte *out, std::byte *gate, std::byte *up, llaisysDataType_t dtype, size_t seqlen, size_t intermediate_size) {
+    switch (dtype) {
+    case LLAISYS_DTYPE_F32:
+        return swiglu_(reinterpret_cast<float *>(out), reinterpret_cast<const float *>(gate), reinterpret_cast<const float *>(up), seqlen, intermediate_size);
+    case LLAISYS_DTYPE_F16:
+        return swiglu_(reinterpret_cast<llaisys::fp16_t *>(out), reinterpret_cast<const llaisys::fp16_t *>(gate), reinterpret_cast<const llaisys::fp16_t *>(up), seqlen, intermediate_size);
+    case LLAISYS_DTYPE_BF16:
+        return swiglu_(reinterpret_cast<llaisys::bf16_t *>(out), reinterpret_cast<const llaisys::bf16_t *>(gate), reinterpret_cast<const llaisys::bf16_t *>(up), seqlen, intermediate_size);
+    default:
+        EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+    }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/swiglu/cpu/swiglu_cpu.hpp b/src/ops/swiglu/cpu/swiglu_cpu.hpp new file mode 100644 index 000000000..49d825bd8 --- /dev/null +++ b/src/ops/swiglu/cpu/swiglu_cpu.hpp @@ -0,0 +1,7 @@
+#pragma once
+
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void swiglu(std::byte *out, std::byte *gate, std::byte *up, llaisysDataType_t dtype, size_t seqlen, size_t intermediate_size);
+}
\ No newline at end of file
diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp index 47edbcc97..afdcc98ad 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,29 @@
 #include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+#include "cpu/swiglu_cpu.hpp"
 
 namespace llaisys::ops {
 void swiglu(tensor_t out, tensor_t gate, tensor_t up) {
-    TO_BE_IMPLEMENTED();
+    CHECK_SAME_DEVICE(out, gate, up);
+    CHECK_SAME_SHAPE(out->shape(), gate->shape(), up->shape());
+
+    ASSERT(out->ndim() == 2, "swiglu only supports 2D tensors for now");
+
+    if (out->deviceType() == LLAISYS_DEVICE_CPU) {
+        return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), out->shape()[0], out->shape()[1]);
+    }
+
+    llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+    switch (out->deviceType()) {
+    case LLAISYS_DEVICE_CPU:
+        return cpu::swiglu(out->data(), gate->data(), up->data(), out->dtype(), out->shape()[0], out->shape()[1]);
+#ifdef ENABLE_NVIDIA_API
+    case LLAISYS_DEVICE_NVIDIA:
+        TO_BE_IMPLEMENTED();
+        return;
+#endif
+    default:
+        EXCEPTION_UNSUPPORTED_DEVICE;
+    }
 }
 } // namespace llaisys::ops
diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb65..7e18fd94e 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -164,27 +164,94 @@ void Tensor::debug() const {
 }
 
 bool Tensor::isContiguous() const {
-    TO_BE_IMPLEMENTED();
+    if (this->numel() == 0) {
+        return true;
+    }
+    if (this->ndim()
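The SwiGLU kernel above computes `out[i] = up[i] * sigmoid(gate[i]) * gate[i]` elementwise, i.e. `up * silu(gate)`. A minimal float reference (hypothetical name, not part of the repo) pins down the expected behavior at the extremes:

```python
import math

def swiglu_ref(gate, up):
    """Elementwise SwiGLU: out[i] = up[i] * sigmoid(gate[i]) * gate[i]."""
    assert len(gate) == len(up)
    return [u * (1.0 / (1.0 + math.exp(-g))) * g for g, u in zip(gate, up)]
```

For large positive `gate` the sigmoid saturates to 1 and the output tends to `up * gate`; for large negative `gate` the output tends to 0 — the gating behavior that gives the activation its name.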
== 0) {
+        return true;
+    }
+    ptrdiff_t expected_stride = 1;
+    const size_t ndim_ = this->ndim();
+    const std::vector<size_t> &shape_ = this->shape();
+
+    for (size_t i = 1; i <= ndim_; i++) {
+        if (this->strides()[ndim_ - i] != expected_stride) {
+            return false;
+        }
+        expected_stride *= shape_[ndim_ - i];
+    }
     return true;
 }
 
 tensor_t Tensor::permute(const std::vector<size_t> &order) const {
-    TO_BE_IMPLEMENTED();
-    return std::shared_ptr<Tensor>(new Tensor(_meta, _storage));
+    const size_t ndim_ = this->ndim();
+    CHECK_ARGUMENT(order.size() == ndim_, "permute order size must match tensor ndim");
+
+    std::vector<bool> seen(ndim_, false);
+    for (const size_t dim : order) {
+        CHECK_ARGUMENT(dim < ndim_, "permute order out of range");
+        CHECK_ARGUMENT(!seen[dim], "permute order has duplicate entries");
+        seen[dim] = true;
+    }
+
+    std::vector<size_t> new_shape(ndim_);
+    std::vector<ptrdiff_t> new_strides(ndim_);
+
+    for (size_t i = 0; i < ndim_; i++) {
+        new_shape[i] = _meta.shape[order[i]];
+        new_strides[i] = _meta.strides[order[i]];
+    }
+
+    TensorMeta meta{this->dtype(), new_shape, new_strides};
+
+    return std::shared_ptr<Tensor>(new Tensor(meta, _storage));
 }
 
 tensor_t Tensor::view(const std::vector<size_t> &shape) const {
-    TO_BE_IMPLEMENTED();
-    return std::shared_ptr<Tensor>(new Tensor(_meta, _storage));
+    const size_t new_numel = std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
+    CHECK_ARGUMENT(new_numel == this->numel(), "cannot view tensor with different number of elements");
+    const size_t ndim_ = shape.size();
+    std::vector<ptrdiff_t> new_strides(ndim_);
+    std::vector<size_t> new_shape = shape;
+    size_t stride = 1;
+    for (size_t i = 1; i <= ndim_; i++) {
+        new_strides[ndim_ - i] = stride;
+        stride *= shape[ndim_ - i];
+    }
+
+    TensorMeta meta{this->dtype(), new_shape, new_strides};
+    return std::shared_ptr<Tensor>(new Tensor(meta, _storage));
 }
 
 tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const {
-    TO_BE_IMPLEMENTED();
-    return std::shared_ptr<Tensor>(new Tensor(_meta, _storage));
+    CHECK_ARGUMENT(dim < this->ndim(), "slice
dimension out of range"); + CHECK_ARGUMENT(start <= end && end <= this->shape()[dim], "invalid slice range"); + std::vector new_shape(this->shape()); + new_shape[dim] = end - start; + + std::vector new_strides = this->strides(); + + // 计算新的 offset:start 个元素在该维度上的字节偏移 + // strides[dim] 是元素个数,需要乘以元素大小转换为字节 + size_t new_offset = _offset + start * this->strides()[dim] * this->elementSize(); + + TensorMeta meta{this->dtype(), new_shape, new_strides}; + return std::shared_ptr(new Tensor(meta, _storage, new_offset)); } void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + CHECK_ARGUMENT(src_ != nullptr, "source pointer is null"); + const size_t bytes = this->numel() * this->elementSize(); + if (bytes == 0) { + return; + } + const llaisysMemcpyKind_t mem_cpy_type = (this->deviceType() == LLAISYS_DEVICE_CPU) ? LLAISYS_MEMCPY_H2H : LLAISYS_MEMCPY_H2D; + core::context().setDevice(this->deviceType(), this->deviceId()); + core::context().runtime().api()->memcpy_sync( + this->data(), + src_, + bytes, + mem_cpy_type); } tensor_t Tensor::contiguous() const { diff --git a/test/ops/sample.py b/test/ops/sample.py new file mode 100644 index 000000000..dc475764c --- /dev/null +++ b/test/ops/sample.py @@ -0,0 +1,128 @@ +import sys +import os + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, parent_dir) +import llaisys +import torch +import numpy as np +from test_utils import random_tensor, check_equal, zero_tensor + + +# ── helpers ────────────────────────────────────────────────────────────────── + +def make_logits_ll(logits_np, dtype_name): + dtype_map = { + "f32": (torch.float32, llaisys.DataType.F32), + "f16": (torch.float16, llaisys.DataType.F16), + "bf16": (torch.bfloat16, llaisys.DataType.BF16), + } + td, ld = dtype_map[dtype_name] + t = torch.tensor(logits_np, dtype=td) + ll = llaisys.Tensor(logits_np.shape, dtype=ld, device=llaisys.DeviceType.CPU) + api = llaisys.RuntimeAPI(llaisys.DeviceType.CPU) + 
api.memcpy_sync(ll.data_ptr(), t.data_ptr(), t.numel() * t.element_size(), llaisys.MemcpyKind.D2D) + return ll + + +def run_sample(logits_np, dtype_name="f32", temperature=1.0, top_k=0, top_p=1.0): + ll = make_logits_ll(logits_np, dtype_name) + out = llaisys.Tensor((1,), dtype=llaisys.DataType.I64, device=llaisys.DeviceType.CPU) + z = torch.zeros((1,), dtype=torch.int64) + api = llaisys.RuntimeAPI(llaisys.DeviceType.CPU) + api.memcpy_sync(out.data_ptr(), z.data_ptr(), 8, llaisys.MemcpyKind.D2D) + llaisys.Ops.sample(out, ll, temperature, top_k, top_p) + r = torch.zeros((1,), dtype=torch.int64) + api.memcpy_sync(r.data_ptr(), out.data_ptr(), 8, llaisys.MemcpyKind.D2D) + return int(r[0]) + + +# ── tests ───────────────────────────────────────────────────────────────────── + +def test_op_sample_temperature(device_name="cpu", profile=False): + """temperature scaling: top_k=1 always returns argmax regardless of temperature.""" + print(" [temperature] top_k=1 always returns argmax") + logits = np.array([1.0, 5.0, 2.0, 3.0], dtype=np.float32) + for temp in [0.5, 1.0, 2.0]: + for dtype in ["f32", "f16", "bf16"]: + results = [run_sample(logits, dtype_name=dtype, temperature=temp, top_k=1) for _ in range(10)] + assert all(r == 1 for r in results), f"temperature={temp} dtype={dtype} failed: {results}" + + +def test_op_sample_distribution(device_name="cpu", profile=False): + """No filtering: sampled distribution should match softmax probabilities.""" + print(" [distribution] multinomial sampling matches softmax probs") + logits = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + expected_probs = torch.softmax(torch.tensor(logits), dim=0).numpy() + + n = 5000 + counts = np.zeros(4) + for _ in range(n): + counts[run_sample(logits, top_k=0, top_p=1.0)] += 1 + observed = counts / n + + # Allow 5% absolute tolerance for statistical test + for i in range(4): + assert abs(observed[i] - expected_probs[i]) < 0.05, ( + f"token {i}: observed={observed[i]:.3f} 
expected={expected_probs[i]:.3f}" + ) + + +def test_op_sample_top_k(device_name="cpu", profile=False): + """top_k=K: only top-K tokens should ever be sampled.""" + print(" [top_k] only top-K tokens sampled") + logits = np.array([1.0, 5.0, 4.0, 0.5], dtype=np.float32) # sorted: idx 1 > 2 > 0 > 3 + + # top_k=1 -> always argmax + results = [run_sample(logits, top_k=1) for _ in range(50)] + assert all(r == 1 for r in results), f"top_k=1 failed: {set(results)}" + + # top_k=2 -> only idx 1 and 2 + counts = [0] * 4 + for _ in range(1000): + counts[run_sample(logits, top_k=2)] += 1 + assert counts[0] == 0 and counts[3] == 0, f"top_k=2 leaked outside top-2: {counts}" + assert counts[1] > 0 and counts[2] > 0, f"top_k=2 missing expected tokens: {counts}" + + +def test_op_sample_top_p(device_name="cpu", profile=False): + """top_p filtering: very low top_p should collapse to argmax.""" + print(" [top_p] top_p=0.0 always returns argmax") + logits = np.array([1.0, 10.0, 2.0, 3.0], dtype=np.float32) + results = [run_sample(logits, top_p=0.0) for _ in range(30)] + assert all(r == 1 for r in results), f"top_p=0.0 failed: {set(results)}" + + # top_p=1.0 should allow all tokens (with enough samples, non-argmax tokens appear) + print(" [top_p] top_p=1.0 allows non-argmax tokens") + logits2 = np.array([3.0, 4.0, 3.0, 3.0], dtype=np.float32) + seen = set() + for _ in range(500): + seen.add(run_sample(logits2, top_p=1.0)) + assert len(seen) > 1, f"top_p=1.0 only sampled {seen}" + + +def test_op_sample_dtype(device_name="cpu", profile=False): + """f16 and bf16 logits: top_k=1 should still return correct argmax.""" + print(" [dtype] f16/bf16 logits return correct argmax with top_k=1") + logits = np.array([1.0, 10.0, 2.0, 3.0], dtype=np.float32) + for dtype in ["f16", "bf16"]: + results = [run_sample(logits, dtype_name=dtype, top_k=1) for _ in range(20)] + assert all(r == 1 for r in results), f"{dtype} top_k=1 failed: {set(results)}" + + +if __name__ == "__main__": + import argparse + + 
parser = argparse.ArgumentParser() + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + + print(f"Testing Ops.sample on {args.device}") + test_op_sample_temperature(args.device, args.profile) + test_op_sample_distribution(args.device, args.profile) + test_op_sample_top_k(args.device, args.profile) + test_op_sample_top_p(args.device, args.profile) + test_op_sample_dtype(args.device, args.profile) + + print("\033[92mTest passed!\033[0m\n") diff --git a/test/server/__init__.py b/test/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/server/main.py b/test/server/main.py new file mode 100644 index 000000000..dca71188c --- /dev/null +++ b/test/server/main.py @@ -0,0 +1,100 @@ +""" +main.py — FastAPI application entry point + +This file sets up the FastAPI app, mounts static files, and provides the +CLI entry point for starting the chatbot server. + +Usage: + python -m llaisys.server.main --model /path/to/model --port 8000 + +Architecture overview: + Browser <──SSE──> FastAPI (main.py + routes.py) <──ctypes──> C++ backend + │ + Tokenizer (HuggingFace transformers) + +The server is intentionally single-user and single-threaded on the model side: +one request is served at a time. This matches the Project #3 requirement and +keeps the code simple. Multi-user batching is covered in Project #4. +""" + +import argparse +import logging + +import uvicorn +from fastapi import FastAPI +from fastapi.staticfiles import StaticFiles +from fastapi.responses import FileResponse +from pathlib import Path + +from routes import router, set_model + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Create the FastAPI application. +# Including the router from routes.py registers all API endpoints. 
+app = FastAPI(title="LLAISYS Chatbot")
+app.include_router(router)
+
+# Serve static files (HTML, CSS, JS) under the /static URL prefix.
+# Path(__file__).parent resolves to the directory containing this file,
+# so "static/" is always found relative to main.py regardless of where
+# the server is launched from.
+STATIC_DIR = Path(__file__).parent / "static"
+app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
+
+
+@app.get("/")
+async def index():
+    """Serve the chat UI. The browser loads this on first visit."""
+    return FileResponse(STATIC_DIR / "index.html")
+
+
+@app.get("/health")
+async def health():
+    """
+    Health check endpoint.
+    Returns whether the model has been loaded successfully.
+    Useful for monitoring or waiting for the server to be ready.
+    """
+    from routes import MODEL
+    return {"status": "ok", "model_loaded": MODEL is not None}
+
+
+def main():
+    """
+    CLI entry point: load the model then start the HTTP server.
+
+    Model loading happens before uvicorn starts so that the first request
+    doesn't have to wait. The model stays in memory for the lifetime of
+    the server process.
+ """ + parser = argparse.ArgumentParser(description="LLAISYS Chatbot Server") + parser.add_argument("--model", required=True, help="Path to model directory") + parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"]) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--host", default="localhost") + args = parser.parse_args() + + from transformers import AutoTokenizer + import llaisys + + logger.info(f"Loading model from {args.model} ...") + device = llaisys.DeviceType.CPU if args.device == "cpu" else llaisys.DeviceType.NVIDIA + + # Load the LLAISYS C++ model (weights are read from safetensors files) + model = llaisys.models.Qwen2(args.model, device=device) + + # Load the tokenizer from HuggingFace — used to encode prompts and + # decode generated token IDs back to text + tokenizer = AutoTokenizer.from_pretrained(args.model) + + # Make model and tokenizer available to the route handlers + set_model(model, tokenizer) + logger.info(f"Model loaded. Starting server at http://{args.host}:{args.port}") + + uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/test/server/routes.py b/test/server/routes.py new file mode 100644 index 000000000..1292b2d60 --- /dev/null +++ b/test/server/routes.py @@ -0,0 +1,375 @@ +""" +routes.py — Chat completion API endpoints + session management + +This module implements: + - /v1/chat/completions — generate a response for a given session + - /v1/sessions — list, create, and delete chat sessions + +Session management: + Each session is an independent conversation with its own message history. + Sessions are stored in memory (SESSIONS dict) and identified by a UUID. + The client sends a session_id with every chat request so the server knows + which history to use. + +Key concepts: + - Streaming vs non-streaming: controlled by the `stream` field in the request. 
+    Streaming uses Server-Sent Events (SSE), which lets the browser receive
+    tokens one by one as they are generated, instead of waiting for the full
+    response.
+  - run_in_executor: The C++ model inference is synchronous (blocking). FastAPI
+    runs on an async event loop (asyncio), so if we call blocking code directly
+    it freezes the entire server. run_in_executor offloads the blocking call to
+    a thread pool, keeping the event loop free to handle SSE writes.
+  - Incremental decode: We decode the full list of generated tokens on every
+    step and take the new suffix as the delta. This avoids BPE boundary issues
+    where a single token only produces visible text when combined with the next.
+  - <think> filtering: DeepSeek-R1 wraps its chain-of-thought reasoning in
+    <think>...</think> tags. We strip these before sending text to the client.
+"""
+
+import asyncio
+import json
+import re
+import time
+import uuid
+from typing import List
+
+from fastapi import APIRouter, HTTPException
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel, Field
+
+router = APIRouter()
+
+# Global model and tokenizer, set once at server startup via set_model().
+MODEL = None
+TOKENIZER = None
+
+# ── Session store ─────────────────────────────────────────────────────────────
+# Each session is a dict with:
+#   id      — unique UUID string
+#   title   — display name (set to the first user message, truncated)
+#   history — list of {role, content} dicts, the full conversation so far
+#
+# In a production system this would be persisted to a database, but for a
+# single-user educational server, an in-memory dict is sufficient.
+SESSIONS: dict = {} # session_id -> session dict + + +def _new_session() -> dict: + """Create a new empty session and add it to the store.""" + sid = uuid.uuid4().hex[:8] + session = {"id": sid, "title": "New chat", "history": []} + SESSIONS[sid] = session + return session + + +def set_model(model, tokenizer): + """Called by main.py after loading the model to make it available here.""" + global MODEL, TOKENIZER + MODEL = model + TOKENIZER = tokenizer + + +# ── Request / Response schemas ──────────────────────────────────────────────── +# Pydantic models validate incoming JSON automatically and produce clear error +# messages when fields are missing or out of range. + +class ChatMessage(BaseModel): + role: str # "user" or "assistant" + content: str # the message text + + +class ChatCompletionRequest(BaseModel): + messages: List[ChatMessage] + session_id: str = "" # Which session to use; empty = create a new one + # Sampling parameters — see sample operator for algorithm details + temperature: float = Field(default=0.8, ge=0.1, le=2.0, + description="Higher = more random, lower = more deterministic") + top_k: int = Field(default=50, ge=0, + description="Keep only the top-K most likely tokens (0 = disabled)") + top_p: float = Field(default=0.9, ge=0.0, le=1.0, + description="Nucleus sampling: keep tokens whose cumulative prob >= top_p") + max_tokens: int = Field(default=512, gt=0, + description="Maximum number of tokens to generate") + stream: bool = False # If True, use SSE streaming response + + +# ── Helper functions ────────────────────────────────────────────────────────── + +def _clean_output(text: str) -> str: + """ + Remove model-specific artifacts from generated text. + + DeepSeek-R1 outputs two kinds of noise we need to strip: + 1. ... blocks — the model's internal chain-of-thought + reasoning. We don't want to show this to the user. + 2. Special tokens like <|end▁of▁sentence|> — control tokens that the + tokenizer emits when skip_special_tokens=False. 
We need to decode with
+       skip_special_tokens=False so that <think> tags are preserved for
+       filtering, but then we must manually remove the other special tokens.
+
+    Note: the chat template appends "<think>\n" to the prompt, so generated
+    tokens start *inside* the think block (no opening tag). We handle
+    both cases: full <think>...</think> blocks and orphaned </think> tags.
+    """
+    # Strip full chain-of-thought block when both tags are present
+    text = re.sub(r"<think>[\s\S]*?</think>", "", text)
+    # Strip everything up to and including an orphaned </think> tag
+    # (happens when <think> was part of the prompt template, not the output)
+    text = re.sub(r"^[\s\S]*?</think>", "", text)
+    # Strip special tokens (both ASCII <|...|> and fullwidth <｜...｜> variants)
+    text = re.sub(r"<[|｜][^|｜]*[|｜]>", "", text)
+    return text.strip()
+
+
+def _check_model():
+    """Raise HTTP 503 if the model hasn't been loaded yet."""
+    if MODEL is None:
+        raise HTTPException(status_code=503, detail="Model not loaded")
+
+
+def _encode(messages: List[ChatMessage]) -> List[int]:
+    """
+    Apply the model's chat template and encode to token IDs.
+
+    apply_chat_template formats the conversation history into the exact
+    prompt format the model was trained on (e.g. <|User|>...<|Assistant|>).
+    This is important — using the wrong format degrades response quality.
+    """
+    chat = [{"role": m.role, "content": m.content} for m in messages]
+    text = TOKENIZER.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+    return TOKENIZER.encode(text)
+
+
+def _make_chunk(cid: str, content: str) -> str:
+    """
+    Format a text delta as an SSE data line in OpenAI chunk format.
+
+    SSE (Server-Sent Events) protocol: each event is a line starting with
+    "data: " followed by the payload, terminated by two newlines.
+    The client reads these line by line and updates the UI incrementally.
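The effect of this cleaning step can be pinned down with a standalone miniature (assuming DeepSeek-style `<think>` markers; the special-token pattern here covers only the ASCII `<|...|>` form):

```python
import re

def clean_output(text):
    # Full chain-of-thought block: <think> ... </think>
    text = re.sub(r"<think>[\s\S]*?</think>", "", text)
    # Orphaned closing tag: drop everything up to and including </think>
    text = re.sub(r"^[\s\S]*?</think>", "", text)
    # ASCII special tokens such as <|end_of_sentence|>
    text = re.sub(r"<\|[^|]*\|>", "", text)
    return text.strip()
```

The second regex is the one that matters in practice: when the chat template already opened the think block, the model's output contains only the closing tag.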
+ """ + data = { + "id": cid, + "object": "chat.completion.chunk", + "created": int(time.time()), + "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}], + } + return f"data: {json.dumps(data)}\n\n" + + +def _make_done_chunk(cid: str) -> str: + """ + Send the final SSE event signaling end of stream. + + The OpenAI streaming protocol ends with a chunk where finish_reason="stop" + followed by a special "data: [DONE]" line. The client uses [DONE] to know + it can close the connection. + """ + data = { + "id": cid, + "object": "chat.completion.chunk", + "created": int(time.time()), + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + } + return f"data: {json.dumps(data)}\n\ndata: [DONE]\n\n" + + +# ── Session endpoints ───────────────────────────────────────────────────────── + +@router.get("/v1/sessions") +async def list_sessions(): + """Return all sessions (id + title only, not full history).""" + return [{"id": s["id"], "title": s["title"]} for s in SESSIONS.values()] + + +@router.post("/v1/sessions") +async def create_session(): + """Create a new empty session and return it.""" + session = _new_session() + return {"id": session["id"], "title": session["title"]} + + +@router.get("/v1/sessions/{session_id}") +async def get_session(session_id: str): + """Return a session's full history (id, title, history).""" + if session_id not in SESSIONS: + raise HTTPException(status_code=404, detail="Session not found") + s = SESSIONS[session_id] + return {"id": s["id"], "title": s["title"], "history": s["history"]} + + +@router.delete("/v1/sessions/{session_id}") +async def delete_session(session_id: str): + """Delete a session by ID.""" + if session_id not in SESSIONS: + raise HTTPException(status_code=404, detail="Session not found") + del SESSIONS[session_id] + return {"deleted": session_id} + + +# ── Chat endpoint ───────────────────────────────────────────────────────────── + +@router.post("/v1/chat/completions") +async def 
chat_completions(req: ChatCompletionRequest): + """ + Main chat endpoint, compatible with the OpenAI chat completion API. + + The client sends the full message list and a session_id. The server + updates the session's history after each successful generation so that + subsequent requests in the same session have full context. + """ + _check_model() + + if not req.messages: + raise HTTPException(status_code=400, detail="messages must not be empty") + + # Resolve or create the session + if req.session_id and req.session_id in SESSIONS: + session = SESSIONS[req.session_id] + else: + # No valid session_id provided — create a new session automatically + session = _new_session() + + # The client sends the full history it knows about; we use that as the + # source of truth (the client may have added a new user message). + session["history"] = [{"role": m.role, "content": m.content} for m in req.messages] + + # Set the session title from the first user message (truncated to 30 chars) + user_msgs = [m for m in req.messages if m.role == "user"] + if user_msgs and session["title"] == "New chat": + session["title"] = user_msgs[0].content[:30] + + # Encode the full conversation history into a flat token sequence. + # The model sees the entire history on every turn — this is how it + # maintains context. The KV cache in the C++ backend avoids recomputing + # attention for tokens it has already seen. 
+ tokens = _encode(req.messages) + cid = f"chatcmpl-{uuid.uuid4().hex[:8]}" + + if req.stream: + return StreamingResponse( + _stream_generate(tokens, cid, req, session), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "X-Accel-Buffering": "no", + # Send session_id in header so client can update its state + "X-Session-Id": session["id"], + }, + ) + + # ── Non-streaming path ──────────────────────────────────────────────────── + generated = [] + current_tokens = list(tokens) + for _ in range(req.max_tokens): + next_token = MODEL.generate( + current_tokens, max_new_tokens=1, + temperature=req.temperature, top_k=req.top_k, top_p=req.top_p, + )[-1] + current_tokens.append(next_token) + generated.append(next_token) + if MODEL._end_token >= 0 and next_token == MODEL._end_token: + break + + text = _clean_output(TOKENIZER.decode(generated, skip_special_tokens=False)) + + # Persist the assistant reply into the session history + session["history"].append({"role": "assistant", "content": text}) + + return { + "id": cid, + "session_id": session["id"], + "object": "chat.completion", + "created": int(time.time()), + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": text}, + "finish_reason": "stop", + }], + "usage": { + "prompt_tokens": len(tokens), + "completion_tokens": len(generated), + "total_tokens": len(tokens) + len(generated), + }, + } + + +async def _stream_generate(tokens: List[int], cid: str, req: ChatCompletionRequest, session: dict): + """ + Async generator that yields SSE chunks one token at a time. + + The challenge: model inference is synchronous C++ code, but we need to + yield SSE chunks asynchronously. Solution: run_in_executor offloads each + blocking inference call to a thread pool worker, then awaits the result. + This lets the asyncio event loop stay responsive and flush each chunk to + the HTTP client immediately after it's generated. 
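The executor hand-off described in this docstring is easy to demonstrate in isolation: a blocking call runs in the default thread pool while the event loop keeps servicing other coroutines. A standalone sketch (hypothetical names, not server code):

```python
import asyncio
import time

def blocking_infer(x):
    time.sleep(0.05)   # stand-in for a synchronous C++ inference call
    return x + 1

async def main():
    loop = asyncio.get_running_loop()
    ticks = 0

    async def ticker():
        # Proof the event loop is still alive while blocking_infer runs.
        nonlocal ticks
        while True:
            await asyncio.sleep(0.01)
            ticks += 1

    t = asyncio.ensure_future(ticker())
    # Offload the blocking call to the default thread pool and await it.
    token = await loop.run_in_executor(None, blocking_infer, 41)
    t.cancel()
    return token, ticks

token, ticks = asyncio.run(main())
```

Had `blocking_infer` been called directly inside `main()`, `ticks` would stay at 0 for the duration of the call — the SSE analogue being that no chunk could be flushed until generation finished.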
+ + Incremental decode strategy: + We decode the full list of generated tokens on every step and compare + with the previous decoded length to get the new text delta. This is + necessary because some tokens only produce visible characters when + decoded together (BPE subword tokenization). Decoding token-by-token + would produce garbled output for multi-byte characters. + + History persistence strategy: + Add assistant message to history immediately with empty content, then + update it after each token. This ensures partial responses are saved + even if the client disconnects mid-stream. + """ + loop = asyncio.get_event_loop() + current_tokens = list(tokens) + generated = [] + prev_clean_len = 0 # Length of clean text already sent to the client + + # Add assistant message placeholder to history immediately + # This ensures we save partial responses if client disconnects + assistant_msg = {"role": "assistant", "content": ""} + session["history"].append(assistant_msg) + + def _infer_one(): + """Synchronous inference call — runs in a thread pool worker.""" + return MODEL.generate( + current_tokens, max_new_tokens=1, + temperature=req.temperature, top_k=req.top_k, top_p=req.top_p, + )[-1] + + try: + for _ in range(req.max_tokens): + # Offload blocking C++ call to thread pool, await the result. + # While waiting, the event loop can handle other tasks (e.g. flush + # already-yielded chunks to the HTTP client). + next_token = await loop.run_in_executor(None, _infer_one) + current_tokens.append(next_token) + generated.append(next_token) + + # Stop generation when end-of-sequence token is produced + if MODEL._end_token >= 0 and next_token == MODEL._end_token: + break + + # Decode all generated tokens so far (not just the latest one). + # This handles BPE boundaries correctly: a token like "▁Hello" + # only decodes to " Hello" when seen in context. 
+ raw = TOKENIZER.decode(generated, skip_special_tokens=False) + clean = _clean_output(raw) + + # Update the assistant message in history with current content + # This ensures partial responses are saved if client disconnects + assistant_msg["content"] = clean + + # The new visible text is whatever was added since last time + delta = clean[prev_clean_len:] + if delta: + prev_clean_len = len(clean) + yield _make_chunk(cid, delta) + + except Exception as e: + # Send error as an SSE event so the client can display it + err = {"error": {"message": str(e), "type": "server_error"}} + yield f"data: {json.dumps(err)}\n\n" + # Update history with partial content even on error + if generated: + assistant_msg["content"] = _clean_output(TOKENIZER.decode(generated, skip_special_tokens=False)) + + # Always send [DONE] to signal end of stream + yield _make_done_chunk(cid) diff --git a/test/server/static/app.js b/test/server/static/app.js new file mode 100644 index 000000000..a5f0a9e61 --- /dev/null +++ b/test/server/static/app.js @@ -0,0 +1,324 @@ +/** + * app.js — Chat UI with multi-session management + * + * Session management overview: + * - Each "session" is an independent conversation stored on the server. + * - The sidebar lists all sessions; clicking one switches the active session. + * - The "+" button creates a new empty session. + * - Each session has its own message history displayed in the chat panel. + * - On page load, we fetch existing sessions from the server and restore them. 
+ * + * Client-side state: + * sessions — map of session_id -> { id, title, messages[] } + * messages[] holds the DOM-ready history for rendering + * activeId — the currently displayed session's id + */ + +// ── DOM references ──────────────────────────────────────────────────────────── + +const messagesEl = document.getElementById('messages'); +const inputEl = document.getElementById('input'); +const sendBtn = document.getElementById('sendBtn'); +const settingsBtn = document.getElementById('settingsBtn'); +const settingsPanel = document.getElementById('settingsPanel'); +const sessionListEl = document.getElementById('sessionList'); +const newChatBtn = document.getElementById('newChatBtn'); +const chatTitleEl = document.getElementById('chatTitle'); + +const tempSlider = document.getElementById('temperature'); +const topkSlider = document.getElementById('topK'); +const toppSlider = document.getElementById('topP'); +const maxTokensInput = document.getElementById('maxTokens'); +const tempVal = document.getElementById('tempVal'); +const topkVal = document.getElementById('topkVal'); +const toppVal = document.getElementById('toppVal'); + +tempSlider.addEventListener('input', () => tempVal.textContent = tempSlider.value); +topkSlider.addEventListener('input', () => topkVal.textContent = topkSlider.value); +toppSlider.addEventListener('input', () => toppVal.textContent = toppSlider.value); +settingsBtn.addEventListener('click', () => { settingsPanel.hidden = !settingsPanel.hidden; }); + +// ── Client-side session state ───────────────────────────────────────────────── + +/** + * sessions: { [id]: { id, title, history } } + * history is the array of {role, content} sent to the server on each request. + * We maintain it client-side so switching sessions is instant (no server round-trip). 
+ */ +let sessions = {}; +let activeId = null; // currently displayed session id + +// ── Session API helpers ─────────────────────────────────────────────────────── + +async function fetchSessions() { + const res = await fetch('/v1/sessions'); + return res.json(); // [{id, title}, ...] +} + +async function createSessionOnServer() { + const res = await fetch('/v1/sessions', { method: 'POST' }); + return res.json(); // {id, title} +} + +async function deleteSessionOnServer(id) { + await fetch(`/v1/sessions/${id}`, { method: 'DELETE' }); +} + +// ── Sidebar rendering ───────────────────────────────────────────────────────── + +function renderSidebar() { + sessionListEl.innerHTML = ''; + for (const s of Object.values(sessions)) { + const li = document.createElement('li'); + li.className = 'session-item' + (s.id === activeId ? ' active' : ''); + li.dataset.id = s.id; + + const titleSpan = document.createElement('span'); + titleSpan.className = 'session-item-title'; + titleSpan.textContent = s.title; + + const delBtn = document.createElement('button'); + delBtn.className = 'session-delete'; + delBtn.textContent = '✕'; + delBtn.title = 'Delete'; + delBtn.addEventListener('click', async (e) => { + e.stopPropagation(); // Don't trigger the session switch + await deleteSessionOnServer(s.id); + delete sessions[s.id]; + // If we deleted the active session, switch to another or create new + if (activeId === s.id) { + const remaining = Object.keys(sessions); + if (remaining.length > 0) { + switchSession(remaining[0]); + } else { + await newChat(); + } + } else { + renderSidebar(); + } + }); + + li.appendChild(titleSpan); + li.appendChild(delBtn); + li.addEventListener('click', () => switchSession(s.id)); + sessionListEl.appendChild(li); + } +} + +// ── Session switching ───────────────────────────────────────────────────────── + +function switchSession(id) { + activeId = id; + const session = sessions[id]; + + // Update header title + chatTitleEl.textContent = session.title; + 
+  // Re-render the message list for this session
+  messagesEl.innerHTML = '';
+  for (const msg of session.history) {
+    addMessage(msg.role, msg.content);
+  }
+
+  renderSidebar();
+  inputEl.focus();
+}
+
+async function newChat() {
+  // Create session on server so it gets a persistent id
+  const s = await createSessionOnServer();
+  sessions[s.id] = { id: s.id, title: s.title, history: [] };
+  switchSession(s.id);
+}
+
+// ── UI helpers ──────────────────────────────────────────────────────────────
+
+function addMessage(role, content) {
+  const msg = document.createElement('div');
+  msg.className = `message ${role}`;
+
+  const label = document.createElement('div');
+  label.className = 'role-label';
+  label.textContent = role === 'user' ? 'You' : 'Assistant';
+
+  const bubble = document.createElement('div');
+  bubble.className = 'bubble';
+  bubble.textContent = content;
+
+  msg.appendChild(label);
+  msg.appendChild(bubble);
+  messagesEl.appendChild(msg);
+  scrollToBottom();
+  return bubble;
+}
+
+function addError(text) {
+  const el = document.createElement('div');
+  el.className = 'error-msg';
+  el.textContent = text;
+  messagesEl.appendChild(el);
+  scrollToBottom();
+}
+
+function scrollToBottom() {
+  messagesEl.scrollTop = messagesEl.scrollHeight;
+}
+
+function cleanModelOutput(text) {
+  return text
+    .replace(/<think>[\s\S]*?<\/think>/g, '')
+    .replace(/^[\s\S]*?<\/think>/g, '')
+    .replace(/<\|[^|]*\|>/g, '')
+    .trim();
+}
+
+// ── Send message ────────────────────────────────────────────────────────────
+
+async function sendMessage() {
+  const text = inputEl.value.trim();
+  if (!text || !activeId) return;
+
+  inputEl.value = '';
+  inputEl.style.height = 'auto';
+  sendBtn.disabled = true;
+
+  const session = sessions[activeId];
+
+  // Add user message to local history and display it
+  session.history.push({ role: 'user', content: text });
+  addMessage('user', text);
+
+  // Update session title from first user message
+  if (session.title === 'New chat') {
+    session.title = text.slice(0, 30);
+    chatTitleEl.textContent = session.title;
+    renderSidebar();
+  }
+
+  // Create empty assistant bubble with typing indicator
+  const assistantMsg = document.createElement('div');
+  assistantMsg.className = 'message assistant typing';
+  const label = document.createElement('div');
+  label.className = 'role-label';
+  label.textContent = 'Assistant';
+  const bubble = document.createElement('div');
+  bubble.className = 'bubble';
+  assistantMsg.appendChild(label);
+  assistantMsg.appendChild(bubble);
+  messagesEl.appendChild(assistantMsg);
+  scrollToBottom();
+
+  const body = {
+    session_id: activeId,
+    messages: session.history,
+    temperature: parseFloat(tempSlider.value),
+    top_k: parseInt(topkSlider.value),
+    top_p: parseFloat(toppSlider.value),
+    max_tokens: parseInt(maxTokensInput.value),
+    stream: true,
+  };
+
+  let fullText = '';
+  try {
+    const resp = await fetch('/v1/chat/completions', {
+      method: 'POST',
+      headers: { 'Content-Type': 'application/json' },
+      body: JSON.stringify(body),
+    });
+
+    if (!resp.ok) {
+      const err = await resp.json().catch(() => ({ detail: resp.statusText }));
+      throw new Error(err.detail || resp.statusText);
+    }
+
+    // Read SSE stream
+    const reader = resp.body.getReader();
+    const decoder = new TextDecoder();
+    let buf = '';
+
+    while (true) {
+      const { done, value } = await reader.read();
+      if (done) break;
+      buf += decoder.decode(value, { stream: true });
+
+      const lines = buf.split('\n');
+      buf = lines.pop();
+
+      for (const line of lines) {
+        if (!line.startsWith('data: ')) continue;
+        const payload = line.slice(6).trim();
+        if (payload === '[DONE]') break;
+        let chunk;
+        try {
+          chunk = JSON.parse(payload);
+        } catch (e) {
+          continue; // skip malformed chunks
+        }
+        // Surface server-sent error events instead of swallowing them in the
+        // malformed-chunk catch above
+        if (chunk.error) throw new Error(chunk.error.message);
+        const delta = chunk.choices?.[0]?.delta?.content;
+        if (delta) {
+          fullText += delta;
+          bubble.textContent = fullText;
+          scrollToBottom();
+        }
+      }
+    }
+
+    assistantMsg.classList.remove('typing');
+    const cleanText = cleanModelOutput(fullText);
+    bubble.textContent = cleanText;
+
+    // Persist assistant reply in local session history
+    session.history.push({ role: 'assistant', content: cleanText });
+
+  } catch (err) {
+    assistantMsg.remove();
+    addError(`Error: ${err.message}`);
+    session.history.pop(); // Remove the user message added optimistically
+  } finally {
+    sendBtn.disabled = false;
+    inputEl.focus();
+  }
+}
+
+// ── Event listeners ─────────────────────────────────────────────────────────
+
+newChatBtn.addEventListener('click', newChat);
+sendBtn.addEventListener('click', sendMessage);
+
+inputEl.addEventListener('keydown', (e) => {
+  if (e.key === 'Enter' && !e.shiftKey) {
+    e.preventDefault();
+    sendMessage();
+  }
+});
+
+inputEl.addEventListener('input', () => {
+  inputEl.style.height = 'auto';
+  inputEl.style.height = Math.min(inputEl.scrollHeight, 120) + 'px';
+});
+
+// ── Initialization ──────────────────────────────────────────────────────────
+
+async function fetchSessionHistory(id) {
+  const res = await fetch(`/v1/sessions/${id}`);
+  if (!res.ok) return [];
+  const data = await res.json();
+  return data.history || [];
+}
+
+async function init() {
+  // Load any sessions that already exist on the server (e.g. after page refresh)
+  const existing = await fetchSessions();
+  for (const s of existing) {
+    const history = await fetchSessionHistory(s.id);
+    sessions[s.id] = { id: s.id, title: s.title, history };
+  }
+
+  if (Object.keys(sessions).length > 0) {
+    // Restore the first session
+    switchSession(Object.keys(sessions)[0]);
+  } else {
+    // No sessions yet — create the first one automatically
+    await newChat();
+  }
+}
+
+init();
diff --git a/test/server/static/index.html b/test/server/static/index.html
new file mode 100644
index 000000000..0b74a6f7d
--- /dev/null
+++ b/test/server/static/index.html
@@ -0,0 +1,58 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+  <meta charset="UTF-8">
+  <meta name="viewport" content="width=device-width, initial-scale=1.0">
+  <title>LLAISYS Chatbot</title>
+  <link rel="stylesheet" href="/static/style.css">
+</head>
+<body>
+  <div class="layout">
+    <aside class="sidebar">
+      <div class="sidebar-header">
+        <span class="sidebar-title">LLAISYS Chatbot</span>
+        <button id="newChatBtn" title="New chat">+</button>
+      </div>
+      <ul id="sessionList" class="session-list"></ul>
+    </aside>
+    <div class="app">
+      <header class="header">
+        <div id="chatTitle" class="header-title">New chat</div>
+        <button id="settingsBtn" class="settings-btn" title="Settings">⚙</button>
+      </header>
+      <div id="settingsPanel" class="settings-panel" hidden>
+        <div class="setting-row">
+          <label>Temperature: <span id="tempVal">0.7</span></label>
+          <input type="range" id="temperature" min="0.1" max="2" step="0.1" value="0.7">
+        </div>
+        <div class="setting-row">
+          <label>Top-K: <span id="topkVal">50</span></label>
+          <input type="range" id="topK" min="0" max="100" step="1" value="50">
+        </div>
+        <div class="setting-row">
+          <label>Top-P: <span id="toppVal">0.9</span></label>
+          <input type="range" id="topP" min="0.05" max="1" step="0.05" value="0.9">
+        </div>
+        <div class="setting-row">
+          <label for="maxTokens">Max tokens</label>
+          <input type="number" id="maxTokens" value="512" min="1">
+        </div>
+      </div>
+      <div id="messages" class="messages"></div>
+      <div class="input-area">
+        <textarea id="input" rows="1"></textarea>
+        <button id="sendBtn">Send</button>
+      </div>
+    </div>
+  </div>
+  <script src="/static/app.js"></script>
+</body>
+</html>
+ + + diff --git a/test/server/static/style.css b/test/server/static/style.css new file mode 100644 index 000000000..2eb8738c8 --- /dev/null +++ b/test/server/static/style.css @@ -0,0 +1,327 @@ +*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; } + +body { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; + background: #f5f5f5; + height: 100vh; + display: flex; + justify-content: center; + align-items: stretch; +} + +/* ── Overall layout: sidebar + chat panel side by side ── */ +.layout { + display: flex; + width: 100%; + max-width: 1000px; + box-shadow: 0 0 20px rgba(0,0,0,0.08); +} + +/* ── Sidebar ── */ +.sidebar { + width: 220px; + flex-shrink: 0; + background: #1e1e2e; + display: flex; + flex-direction: column; + color: #cdd6f4; +} + +.sidebar-header { + display: flex; + align-items: center; + justify-content: space-between; + padding: 16px 14px 12px; + border-bottom: 1px solid #313244; +} +.sidebar-title { font-size: 0.85rem; font-weight: 600; letter-spacing: 0.05em; text-transform: uppercase; color: #a6adc8; } +.sidebar-header button { + background: none; + border: none; + color: #cdd6f4; + font-size: 1.3rem; + cursor: pointer; + padding: 2px 6px; + border-radius: 6px; + line-height: 1; + transition: background 0.15s; +} +.sidebar-header button:hover { background: #313244; } + +.session-list { + list-style: none; + flex: 1; + overflow-y: auto; + padding: 8px 0; +} + +.session-item { + display: flex; + align-items: center; + padding: 8px 14px; + cursor: pointer; + border-radius: 6px; + margin: 2px 6px; + font-size: 0.88rem; + color: #cdd6f4; + transition: background 0.12s; + gap: 8px; +} +.session-item:hover { background: #313244; } +.session-item.active { background: #45475a; color: #fff; } + +.session-item-title { + flex: 1; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.session-delete { + background: none; + border: none; + color: #6c7086; + font-size: 0.9rem; + cursor: pointer; + 
padding: 2px 4px; + border-radius: 4px; + opacity: 0; + transition: opacity 0.15s, color 0.15s; + flex-shrink: 0; +} +.session-item:hover .session-delete { opacity: 1; } +.session-delete:hover { color: #f38ba8; } + +/* ── Chat panel ── */ +.app { + flex: 1; + display: flex; + flex-direction: column; + background: #fff; + min-width: 0; +} + +/* Header */ +.header { + display: flex; + align-items: center; + justify-content: space-between; + padding: 14px 20px; + border-bottom: 1px solid #e5e5e5; + background: #fff; +} +.header-title { font-size: 1.05rem; font-weight: 600; color: #212529; overflow: hidden; text-overflow: ellipsis; white-space: nowrap; } +.settings-btn { + background: none; + border: none; + font-size: 1.3rem; + cursor: pointer; + color: #666; + padding: 4px 8px; + border-radius: 6px; + transition: background 0.15s; + flex-shrink: 0; +} +.settings-btn:hover { background: #f0f0f0; } + +/* Settings panel */ +.settings-panel { + padding: 14px 20px; + border-bottom: 1px solid #e5e5e5; + background: #fafafa; + display: flex; + flex-wrap: wrap; + gap: 16px; +} +.setting-row { display: flex; flex-direction: column; gap: 4px; min-width: 140px; } +.setting-row label { font-size: 0.8rem; color: #555; } +.setting-row input[type="range"] { width: 100%; accent-color: #007bff; } +.setting-row input[type="number"] { width: 80px; padding: 4px 8px; border: 1px solid #ddd; border-radius: 6px; font-size: 0.9rem; } + +/* Messages */ +.messages { + flex: 1; + overflow-y: auto; + padding: 20px; + display: flex; + flex-direction: column; + gap: 12px; +} + +.message { display: flex; flex-direction: column; max-width: 75%; animation: fadeIn 0.15s ease; } +@keyframes fadeIn { from { opacity: 0; transform: translateY(4px); } to { opacity: 1; } } + +.message.user { align-self: flex-end; align-items: flex-end; } +.message.assistant { align-self: flex-start; align-items: flex-start; } + +.bubble { + padding: 10px 14px; + border-radius: 16px; + font-size: 0.95rem; + line-height: 
1.5;
+  white-space: pre-wrap;
+  word-break: break-word;
+}
+.message.user .bubble { background: #007bff; color: #fff; border-bottom-right-radius: 4px; }
+.message.assistant .bubble { background: #f1f3f5; color: #212529; border-bottom-left-radius: 4px; }
+.role-label { font-size: 0.72rem; color: #999; margin-bottom: 3px; padding: 0 4px; }
+
+/* Typing indicator */
+.typing .bubble::after { content: "▋"; animation: blink 0.8s step-start infinite; }
+@keyframes blink { 50% { opacity: 0; } }
+
+/* Input area */
+.input-area { display: flex; gap: 10px; padding: 14px 20px; border-top: 1px solid #e5e5e5; background: #fff; }
+#input {
+  flex: 1; resize: none; border: 1px solid #ddd; border-radius: 10px;
+  padding: 10px 14px; font-size: 0.95rem; font-family: inherit;
+  outline: none; max-height: 120px; overflow-y: auto; transition: border-color 0.15s;
+}
+#input:focus { border-color: #007bff; }
+#sendBtn {
+  padding: 10px 20px; background: #007bff; color: #fff; border: none;
+  border-radius: 10px; font-size: 0.95rem; cursor: pointer; transition: background 0.15s; white-space: nowrap;
+}
+#sendBtn:hover:not(:disabled) { background: #0069d9; }
+#sendBtn:disabled { background: #aaa; cursor: not-allowed; }
+
+/* Error message */
+.error-msg {
+  color: #dc3545; font-size: 0.85rem; padding: 6px 10px;
+  background: #fff5f5; border-radius: 8px; border: 1px solid #f5c6cb;
+}
diff --git a/test/test_infer.py b/test/test_infer.py
index 59d06b874..b5f6e34b2 100644
--- a/test/test_infer.py
+++ b/test/test_infer.py
@@ -25,7 +25,7 @@ def load_hf_model(model_path=None, device_name="cpu"):
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     model = AutoModelForCausalLM.from_pretrained(
         model_path,
-        torch_dtype=torch.bfloat16,
+        dtype=torch.bfloat16,
         device_map=torch_device(device_name),
         trust_remote_code=True,
     )
diff --git a/xmake.lua b/xmake.lua
index 1f65f7a95..da853bebe 100644
--- a/xmake.lua
+++ b/xmake.lua
@@ -95,6 +95,26 @@ target("llaisys-ops")
     on_install(function (target) end)
 target_end()
 
+target("llaisys-models")
+    set_kind("static")
+    add_deps("llaisys-utils")
+    add_deps("llaisys-device")
+    add_deps("llaisys-core")
+    add_deps("llaisys-tensor")
+    add_deps("llaisys-ops")
+
+    set_languages("cxx17")
+    set_warnings("all", "error")
+    if not is_plat("windows") then
+        add_cxflags("-fPIC", "-Wno-unknown-pragmas")
+    end
+
+    add_files("src/models/*.cpp")
+    add_files("src/llaisys/models/*.cc")
+
+    on_install(function (target) end)
+target_end()
+
 target("llaisys")
     set_kind("shared")
     add_deps("llaisys-utils")
@@ -102,10 +122,12 @@ target("llaisys")
     add_deps("llaisys-core")
     add_deps("llaisys-tensor")
     add_deps("llaisys-ops")
+    add_deps("llaisys-models")
 
     set_languages("cxx17")
     set_warnings("all", "error")
     add_files("src/llaisys/*.cc")
+    add_files("src/llaisys/models/*.cc")
 
     set_installdir(".")
 
@@ -119,4 +141,4 @@ target("llaisys")
             os.cp("lib/*.so", "python/llaisys/libllaisys/")
         end
     end)
-target_end()
\ No newline at end of file
+target_end()
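
补充说明:服务器端 `_clean_output` 和前端 `cleanModelOutput` 做的是同一件事——剥掉 DeepSeek 的思维链和特殊 token。下面用一个独立的 Python 小函数演示这套过滤逻辑(`clean_output` 是示意性的重写,示例字符串为虚构,并非补丁里的原始代码):

```python
import re

def clean_output(text: str) -> str:
    # 情况一:输出里有完整的 <think>...</think> 推理块,整体删掉
    text = re.sub(r"<think>[\s\S]*?</think>", "", text)
    # 情况二:<think> 在 prompt 模板里,输出里只出现 </think>,
    # 此时把开头到 </think> 为止的内容全部删掉
    text = re.sub(r"^[\s\S]*?</think>", "", text)
    # 情况三:过滤 <|end_of_sentence|> 之类的特殊 token
    text = re.sub(r"<\|[^|]*\|>", "", text)
    return text.strip()

print(clean_output("<think>step by step...</think>最终答案"))  # 最终答案
```

三条规则依次执行的顺序很重要:先处理完整的推理块,再兜底处理只剩闭合标签的情况,最后清理特殊 token。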
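
另外,`_stream_generate` 里"每步解码全部已生成 token、只输出新增部分"的增量解码策略,可以不借助 tokenizer 来演示:下面把 "token" 简化成 UTF-8 字节块,让"你"这个三字节字符跨在两个块之间(字节串为虚构的演示数据):

```python
# 每个 "token" 是一段 UTF-8 字节;"你"(3 字节)被拆到了两个块里
token_bytes = [b"Hi ", b"\xe4\xbd", b"\xa0\xe5\xa5\xbd"]  # "Hi " + "你好"

# 逐 token 解码:跨块字符被破坏
per_token = "".join(t.decode("utf-8", errors="ignore") for t in token_bytes)

# 前缀解码 + 取 delta(_stream_generate 的做法):结果完整
buf, prev, deltas = b"", "", []
for t in token_bytes:
    buf += t
    text = buf.decode("utf-8", errors="ignore")  # 末尾不完整的字节被丢弃
    deltas.append(text[len(prev):])              # 只输出新增的可见文字
    prev = text

print(per_token)         # "Hi 好"("你" 丢了)
print("".join(deltas))   # "Hi 你好"
```

真实实现里解码由 HuggingFace tokenizer 完成,这里的 `errors="ignore"` 只是用来模拟"不完整的 token 暂时不产生可见字符"这一现象。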