diff --git a/include/llaisys/build_config.h.in b/include/llaisys/build_config.h.in
new file mode 100644
index 000000000..b73b36684
--- /dev/null
+++ b/include/llaisys/build_config.h.in
@@ -0,0 +1,6 @@
+#ifndef LLAISYS_BUILD_CONFIG_H
+#define LLAISYS_BUILD_CONFIG_H
+
+#cmakedefine ENABLE_NVIDIA_API
+
+#endif
diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h
index 7054626d4..9e7726eb6 100644
--- a/include/llaisys/models/qwen2.h
+++ b/include/llaisys/models/qwen2.h
@@ -38,5 +38,10 @@ __C {
__export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model);
__export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken);
+
+ __export int64_t llaisysQwen2ModelInferSample(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken,
+ float temperature, int top_k, float top_p);
+
+ __export void llaisysQwen2ModelResetKVCache(struct LlaisysQwen2Model * model);
}
#endif // LLAISYS_MODELS_QWEN2_H
diff --git a/include/llaisys/ops.h b/include/llaisys/ops.h
index ddb3be246..c631f62d7 100644
--- a/include/llaisys/ops.h
+++ b/include/llaisys/ops.h
@@ -13,6 +13,7 @@ __C {
__export void llaisysROPE(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t pos_ids, float theta);
__export void llaisysSelfAttention(llaisysTensor_t attn_val, llaisysTensor_t q, llaisysTensor_t k, llaisysTensor_t v, float scale);
__export void llaisysSwiGLU(llaisysTensor_t out, llaisysTensor_t gate, llaisysTensor_t up);
+ __export void llaisysSample(llaisysTensor_t out_idx, llaisysTensor_t logits, float temperature, int top_k, float top_p);
}
#endif
diff --git a/python/llaisys/libllaisys/__init__.py b/python/llaisys/libllaisys/__init__.py
index f536fb527..01ad8db2f 100644
--- a/python/llaisys/libllaisys/__init__.py
+++ b/python/llaisys/libllaisys/__init__.py
@@ -12,6 +12,8 @@
from .tensor import llaisysTensor_t
from .tensor import load_tensor
from .ops import load_ops
+from .qwen2 import load_qwen2
+from .qwen2 import LlaisysQwen2Meta, LlaisysQwen2Weights, llaisysQwen2Model_t
def load_shared_library():
@@ -38,6 +40,7 @@ def load_shared_library():
load_runtime(LIB_LLAISYS)
load_tensor(LIB_LLAISYS)
load_ops(LIB_LLAISYS)
+load_qwen2(LIB_LLAISYS)
__all__ = [
diff --git a/python/llaisys/libllaisys/ops.py b/python/llaisys/libllaisys/ops.py
index 5be095eff..2d195dc18 100644
--- a/python/llaisys/libllaisys/ops.py
+++ b/python/llaisys/libllaisys/ops.py
@@ -1,5 +1,5 @@
from .tensor import llaisysTensor_t
-from ctypes import c_float
+from ctypes import c_float, c_int
def load_ops(lib):
lib.llaisysAdd.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t]
@@ -34,3 +34,6 @@ def load_ops(lib):
lib.llaisysSwiGLU.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t]
lib.llaisysSwiGLU.restype = None
+
+ lib.llaisysSample.argtypes = [llaisysTensor_t, llaisysTensor_t, c_float, c_int, c_float]
+ lib.llaisysSample.restype = None
diff --git a/python/llaisys/libllaisys/qwen2.py b/python/llaisys/libllaisys/qwen2.py
new file mode 100644
index 000000000..1ea1cc59d
--- /dev/null
+++ b/python/llaisys/libllaisys/qwen2.py
@@ -0,0 +1,72 @@
+import ctypes
+from ctypes import c_void_p, c_size_t, c_int, c_int64, c_float, Structure, POINTER
+from .llaisys_types import llaisysDataType_t, llaisysDeviceType_t
+from .tensor import llaisysTensor_t
+
+
+class LlaisysQwen2Meta(Structure):
+ _fields_ = [
+ ("dtype", llaisysDataType_t),
+ ("nlayer", c_size_t),
+ ("hs", c_size_t),
+ ("nh", c_size_t),
+ ("nkvh", c_size_t),
+ ("dh", c_size_t),
+ ("di", c_size_t),
+ ("maxseq", c_size_t),
+ ("voc", c_size_t),
+ ("epsilon", c_float),
+ ("theta", c_float),
+ ("end_token", c_int64),
+ ]
+
+
+class LlaisysQwen2Weights(Structure):
+ _fields_ = [
+ ("in_embed", llaisysTensor_t),
+ ("out_embed", llaisysTensor_t),
+ ("out_norm_w", llaisysTensor_t),
+ ("attn_norm_w", POINTER(llaisysTensor_t)),
+ ("attn_q_w", POINTER(llaisysTensor_t)),
+ ("attn_q_b", POINTER(llaisysTensor_t)),
+ ("attn_k_w", POINTER(llaisysTensor_t)),
+ ("attn_k_b", POINTER(llaisysTensor_t)),
+ ("attn_v_w", POINTER(llaisysTensor_t)),
+ ("attn_v_b", POINTER(llaisysTensor_t)),
+ ("attn_o_w", POINTER(llaisysTensor_t)),
+ ("mlp_norm_w", POINTER(llaisysTensor_t)),
+ ("mlp_gate_w", POINTER(llaisysTensor_t)),
+ ("mlp_up_w", POINTER(llaisysTensor_t)),
+ ("mlp_down_w", POINTER(llaisysTensor_t)),
+ ]
+
+
+llaisysQwen2Model_t = c_void_p
+
+
+def load_qwen2(lib):
+ lib.llaisysQwen2ModelCreate.argtypes = [
+ POINTER(LlaisysQwen2Meta),
+ llaisysDeviceType_t,
+ POINTER(c_int),
+ c_int,
+ ]
+ lib.llaisysQwen2ModelCreate.restype = llaisysQwen2Model_t
+
+ lib.llaisysQwen2ModelDestroy.argtypes = [llaisysQwen2Model_t]
+ lib.llaisysQwen2ModelDestroy.restype = None
+
+ lib.llaisysQwen2ModelWeights.argtypes = [llaisysQwen2Model_t]
+ lib.llaisysQwen2ModelWeights.restype = POINTER(LlaisysQwen2Weights)
+
+ lib.llaisysQwen2ModelInfer.argtypes = [llaisysQwen2Model_t, POINTER(c_int64), c_size_t]
+ lib.llaisysQwen2ModelInfer.restype = c_int64
+
+ lib.llaisysQwen2ModelInferSample.argtypes = [
+ llaisysQwen2Model_t, POINTER(c_int64), c_size_t,
+ c_float, c_int, c_float,
+ ]
+ lib.llaisysQwen2ModelInferSample.restype = c_int64
+
+ lib.llaisysQwen2ModelResetKVCache.argtypes = [llaisysQwen2Model_t]
+ lib.llaisysQwen2ModelResetKVCache.restype = None
diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py
index 0d07b0b21..37d7a2a5f 100644
--- a/python/llaisys/models/qwen2.py
+++ b/python/llaisys/models/qwen2.py
@@ -1,23 +1,121 @@
-from typing import Sequence
+from typing import Sequence, Iterator
from ..libllaisys import LIB_LLAISYS
-from ..libllaisys import DeviceType
+from ..libllaisys import DeviceType, DataType
+from ..libllaisys import LlaisysQwen2Meta, LlaisysQwen2Weights
from pathlib import Path
+import ctypes
+import json
import safetensors
+import torch
class Qwen2:
- def __init__(self, model_path, device: DeviceType = DeviceType.CPU):
- # TODO: Implement model constructor
+ DTYPE_MAP = {
+ "bfloat16": DataType.BF16,
+ "float16": DataType.F16,
+ "float32": DataType.F32,
+ }
+ def __init__(self, model_path, device: DeviceType = DeviceType.CPU):
model_path = Path(model_path)
+ with open(model_path / "config.json") as f:
+ config = json.load(f)
+
+ torch_dtype = config.get("torch_dtype", "bfloat16")
+ dtype = self.DTYPE_MAP.get(torch_dtype, DataType.BF16)
+
+ nh = config["num_attention_heads"]
+ nkvh = config["num_key_value_heads"]
+ hs = config["hidden_size"]
+ dh = hs // nh
+
+ meta = LlaisysQwen2Meta()
+ meta.dtype = dtype
+ meta.nlayer = config["num_hidden_layers"]
+ meta.hs = hs
+ meta.nh = nh
+ meta.nkvh = nkvh
+ meta.dh = dh
+ meta.di = config["intermediate_size"]
+ meta.maxseq = min(config.get("max_position_embeddings", 131072), 4096)
+ meta.voc = config["vocab_size"]
+ meta.epsilon = config.get("rms_norm_eps", 1e-6)
+ meta.theta = config.get("rope_theta", 10000.0)
+        end_token = config.get("eos_token_id", 151643)
+        if isinstance(end_token, list):
+            end_token = end_token[0]
+        meta.end_token = end_token
+
+ self._nlayer = meta.nlayer
+ self._end_token = meta.end_token
+ self._device = device
+
+ device_ids = (ctypes.c_int * 1)(0)
+ self._model = LIB_LLAISYS.llaisysQwen2ModelCreate(
+ ctypes.byref(meta),
+ ctypes.c_int(device),
+ device_ids,
+ ctypes.c_int(1),
+ )
+
+ weights_ptr = LIB_LLAISYS.llaisysQwen2ModelWeights(self._model)
+ weights = weights_ptr.contents
+
+ name_map = self._build_name_map(weights)
+
for file in sorted(model_path.glob("*.safetensors")):
- data_ = safetensors.safe_open(file, framework="numpy", device="cpu")
+ data_ = safetensors.safe_open(file, framework="pt", device="cpu")
for name_ in data_.keys():
- ## TODO: load the model weights
- pass
+ if name_ in name_map:
+ tensor_handle = name_map[name_]
+ t = data_.get_tensor(name_).contiguous()
+ LIB_LLAISYS.tensorLoad(tensor_handle, ctypes.c_void_p(t.data_ptr()))
+
+ def _build_name_map(self, weights: LlaisysQwen2Weights):
+ m = {}
+ m["model.embed_tokens.weight"] = weights.in_embed
+ m["lm_head.weight"] = weights.out_embed
+ m["model.norm.weight"] = weights.out_norm_w
+
+ for i in range(self._nlayer):
+ prefix = f"model.layers.{i}"
+ m[f"{prefix}.input_layernorm.weight"] = weights.attn_norm_w[i]
+ m[f"{prefix}.self_attn.q_proj.weight"] = weights.attn_q_w[i]
+ m[f"{prefix}.self_attn.q_proj.bias"] = weights.attn_q_b[i]
+ m[f"{prefix}.self_attn.k_proj.weight"] = weights.attn_k_w[i]
+ m[f"{prefix}.self_attn.k_proj.bias"] = weights.attn_k_b[i]
+ m[f"{prefix}.self_attn.v_proj.weight"] = weights.attn_v_w[i]
+ m[f"{prefix}.self_attn.v_proj.bias"] = weights.attn_v_b[i]
+ m[f"{prefix}.self_attn.o_proj.weight"] = weights.attn_o_w[i]
+ m[f"{prefix}.post_attention_layernorm.weight"] = weights.mlp_norm_w[i]
+ m[f"{prefix}.mlp.gate_proj.weight"] = weights.mlp_gate_w[i]
+ m[f"{prefix}.mlp.up_proj.weight"] = weights.mlp_up_w[i]
+ m[f"{prefix}.mlp.down_proj.weight"] = weights.mlp_down_w[i]
+
+ return m
+
+ def __del__(self):
+ if hasattr(self, "_model") and self._model is not None:
+ LIB_LLAISYS.llaisysQwen2ModelDestroy(self._model)
+ self._model = None
+
+ def reset_kvcache(self):
+ LIB_LLAISYS.llaisysQwen2ModelResetKVCache(self._model)
+
+ def _infer_one(self, token_ids, use_sample, temperature, top_k, top_p):
+ arr = (ctypes.c_int64 * len(token_ids))(*token_ids)
+ n = ctypes.c_size_t(len(token_ids))
+ if use_sample:
+ return LIB_LLAISYS.llaisysQwen2ModelInferSample(
+ self._model, arr, n,
+ ctypes.c_float(temperature),
+ ctypes.c_int(top_k),
+ ctypes.c_float(top_p),
+ )
+ else:
+ return LIB_LLAISYS.llaisysQwen2ModelInfer(self._model, arr, n)
def generate(
self,
@@ -27,7 +125,38 @@ def generate(
top_p: float = 0.8,
temperature: float = 0.8,
):
+ if max_new_tokens is None:
+ max_new_tokens = 128
+
+ use_sample = not (top_k == 1 and temperature == 1.0)
+ tokens = list(inputs)
+
+ next_token = self._infer_one(tokens, use_sample, temperature, top_k, top_p)
+ tokens.append(next_token)
+
+ for _ in range(max_new_tokens - 1):
+ if next_token == self._end_token:
+ break
+ next_token = self._infer_one([next_token], use_sample, temperature, top_k, top_p)
+ tokens.append(next_token)
+
+ return tokens
+
+ def generate_stream(
+ self,
+ inputs: Sequence[int],
+ max_new_tokens: int = 512,
+ top_k: int = 50,
+ top_p: float = 0.9,
+ temperature: float = 0.8,
+ ) -> Iterator[int]:
+ use_sample = not (top_k == 1 and temperature == 1.0)
- # TODO: Implement generate function
+ next_token = self._infer_one(list(inputs), use_sample, temperature, top_k, top_p)
+ yield next_token
- return []
+ for _ in range(max_new_tokens - 1):
+ if next_token == self._end_token:
+ return
+ next_token = self._infer_one([next_token], use_sample, temperature, top_k, top_p)
+ yield next_token
diff --git a/python/llaisys/ops.py b/python/llaisys/ops.py
index ed0180bc8..3fa7770c7 100644
--- a/python/llaisys/ops.py
+++ b/python/llaisys/ops.py
@@ -1,6 +1,6 @@
from .libllaisys import LIB_LLAISYS
from .tensor import Tensor
-from ctypes import c_float, c_int
+from ctypes import c_float, c_int, c_int64
class Ops:
@@ -19,9 +19,10 @@ def embedding(out: Tensor, index: Tensor, weight: Tensor):
)
@staticmethod
- def linear(out: Tensor, inp: Tensor, weight: Tensor, bias: Tensor):
+ def linear(out: Tensor, inp: Tensor, weight: Tensor, bias: Tensor = None):
+ bias_handle = bias.lib_tensor() if bias is not None else None
LIB_LLAISYS.llaisysLinear(
- out.lib_tensor(), inp.lib_tensor(), weight.lib_tensor(), bias.lib_tensor()
+ out.lib_tensor(), inp.lib_tensor(), weight.lib_tensor(), bias_handle
)
@staticmethod
@@ -53,3 +54,10 @@ def self_attention(attn_val: Tensor, q: Tensor, k: Tensor, v: Tensor, scale: flo
@staticmethod
def swiglu(out: Tensor, gate: Tensor, up: Tensor):
LIB_LLAISYS.llaisysSwiGLU(out.lib_tensor(), gate.lib_tensor(), up.lib_tensor())
+
+ @staticmethod
+ def sample(out_idx: Tensor, logits: Tensor, temperature: float = 1.0, top_k: int = 50, top_p: float = 0.9):
+ LIB_LLAISYS.llaisysSample(
+ out_idx.lib_tensor(), logits.lib_tensor(),
+ c_float(temperature), c_int(top_k), c_float(top_p)
+ )
diff --git a/python/llaisys/server.py b/python/llaisys/server.py
new file mode 100644
index 000000000..97f58a230
--- /dev/null
+++ b/python/llaisys/server.py
@@ -0,0 +1,244 @@
+"""
+LLAISYS Chat Server — OpenAI-compatible chat-completion API.
+
+Usage:
+ python -m llaisys.server --model /path/to/model [--host 0.0.0.0] [--port 8000]
+"""
+
+import argparse
+import json
+import time
+import uuid
+import threading
+from pathlib import Path
+from typing import List, Optional
+
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
+from fastapi.staticfiles import StaticFiles
+from pydantic import BaseModel, Field
+
+from transformers import AutoTokenizer
+
+from .models.qwen2 import Qwen2
+from .libllaisys import DeviceType
+
+app = FastAPI(title="LLAISYS Chat Server")
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+_model: Optional[Qwen2] = None
+_tokenizer = None
+_lock = threading.Lock()
+_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+
+
+# ── Pydantic schemas (OpenAI-compatible) ──────────────────────────────────
+
+class ChatMessage(BaseModel):
+ role: str
+ content: str
+
+class ChatCompletionRequest(BaseModel):
+ model: str = "qwen2"
+ messages: List[ChatMessage]
+ max_tokens: Optional[int] = Field(default=512, alias="max_tokens")
+ temperature: Optional[float] = 0.8
+ top_p: Optional[float] = 0.9
+ top_k: Optional[int] = 50
+ stream: Optional[bool] = False
+
+class ChatChoice(BaseModel):
+ index: int = 0
+ message: ChatMessage
+ finish_reason: str = "stop"
+
+class ChatUsage(BaseModel):
+ prompt_tokens: int = 0
+ completion_tokens: int = 0
+ total_tokens: int = 0
+
+class ChatCompletionResponse(BaseModel):
+ id: str
+ object: str = "chat.completion"
+ created: int
+ model: str
+ choices: List[ChatChoice]
+ usage: ChatUsage
+
+
+# ── Helpers ────────────────────────────────────────────────────────────────
+
+def _build_prompt(messages: List[ChatMessage]) -> str:
+ conversation = [{"role": m.role, "content": m.content} for m in messages]
+ return _tokenizer.apply_chat_template(
+ conversation=conversation,
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+
+
+def _generate_stream_chunks(request_id, model_name, input_ids, temperature, top_k, top_p, max_tokens):
+ """Yield SSE data chunks for streaming responses."""
+ _model.reset_kvcache()
+
+ for token_id in _model.generate_stream(
+ input_ids,
+ max_new_tokens=max_tokens,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ ):
+ text = _tokenizer.decode([token_id], skip_special_tokens=True)
+ if not text:
+ continue
+ chunk = {
+ "id": request_id,
+ "object": "chat.completion.chunk",
+ "created": int(time.time()),
+ "model": model_name,
+ "choices": [{
+ "index": 0,
+ "delta": {"content": text},
+ "finish_reason": None,
+ }],
+ }
+ yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
+
+ done_chunk = {
+ "id": request_id,
+ "object": "chat.completion.chunk",
+ "created": int(time.time()),
+ "model": model_name,
+ "choices": [{
+ "index": 0,
+ "delta": {},
+ "finish_reason": "stop",
+ }],
+ }
+ yield f"data: {json.dumps(done_chunk, ensure_ascii=False)}\n\n"
+ yield "data: [DONE]\n\n"
+
+
+# ── Routes ─────────────────────────────────────────────────────────────────
+
+def _find_static_dir() -> Path:
+ candidates = [
+ Path(__file__).parent / "static",
+ Path(__file__).resolve().parent / "static",
+ Path(__file__).resolve().parent.parent.parent / "python" / "llaisys" / "static",
+ ]
+ for c in candidates:
+ if (c / "index.html").is_file():
+ return c
+ return candidates[0]
+
+_static_dir = _find_static_dir()
+
+
+@app.get("/")
+async def index():
+ html_path = _static_dir / "index.html"
+ if not html_path.is_file():
+        return HTMLResponse("<h1>index.html not found</h1>", status_code=500)
+ return FileResponse(html_path)
+
+
+@app.get("/v1/models")
+async def list_models():
+ return {
+ "object": "list",
+ "data": [{"id": _model_name, "object": "model", "owned_by": "llaisys"}],
+ }
+
+
+@app.post("/v1/chat/completions")
+async def chat_completions(req: ChatCompletionRequest):
+ if _model is None:
+ raise HTTPException(status_code=503, detail="Model not loaded")
+
+ prompt_text = _build_prompt(req.messages)
+ input_ids = _tokenizer.encode(prompt_text)
+
+ request_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
+ temperature = req.temperature or 0.8
+ top_k = req.top_k or 50
+ top_p = req.top_p or 0.9
+ max_tokens = req.max_tokens or 512
+
+ if req.stream:
+ def locked_stream():
+ with _lock:
+ yield from _generate_stream_chunks(
+ request_id, req.model, input_ids,
+ temperature, top_k, top_p, max_tokens,
+ )
+ return StreamingResponse(
+ locked_stream(),
+ media_type="text/event-stream",
+ headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
+ )
+
+ with _lock:
+ _model.reset_kvcache()
+ output_tokens = _model.generate(
+ input_ids,
+ max_new_tokens=max_tokens,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ )
+
+ new_tokens = output_tokens[len(input_ids):]
+ text = _tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+ return ChatCompletionResponse(
+ id=request_id,
+ created=int(time.time()),
+ model=req.model,
+ choices=[ChatChoice(message=ChatMessage(role="assistant", content=text))],
+ usage=ChatUsage(
+ prompt_tokens=len(input_ids),
+ completion_tokens=len(new_tokens),
+ total_tokens=len(output_tokens),
+ ),
+ )
+
+
+# ── Server bootstrap ──────────────────────────────────────────────────────
+
+def init_model(model_path: str, device: str = "cpu"):
+ global _model, _tokenizer, _model_name
+ device_type = DeviceType.CPU if device == "cpu" else DeviceType.NVIDIA
+
+ from huggingface_hub import snapshot_download
+ local_path = snapshot_download(model_path)
+
+ print(f"Loading tokenizer from {local_path} ...")
+ _tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True)
+ print(f"Loading LLAISYS model on {device} from {local_path} ...")
+ _model = Qwen2(local_path, device_type)
+ print("Model loaded.")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="LLAISYS Chat Server")
+ parser.add_argument("--model", required=True, type=str, help="Path to model directory")
+ parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str)
+ parser.add_argument("--host", default="0.0.0.0", type=str)
+ parser.add_argument("--port", default=8000, type=int)
+ args = parser.parse_args()
+
+ init_model(args.model, args.device)
+
+ import uvicorn
+ uvicorn.run(app, host=args.host, port=args.port)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/python/llaisys/static/index.html b/python/llaisys/static/index.html
new file mode 100644
index 000000000..482d27903
--- /dev/null
+++ b/python/llaisys/static/index.html
@@ -0,0 +1,345 @@
+
+
+
+
+
+LLAISYS Chat
+
+
+
+
+
+ LLAISYS Chat
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/core/context/context.cpp b/src/core/context/context.cpp
index 44894b9e7..7b9fd1ae2 100644
--- a/src/core/context/context.cpp
+++ b/src/core/context/context.cpp
@@ -50,10 +50,16 @@ Context::~Context() {
}
void Context::setDevice(llaisysDeviceType_t device_type, int device_id) {
- // If doest not match the current runtime.
if (_current_runtime == nullptr || _current_runtime->deviceType() != device_type || _current_runtime->deviceId() != device_id) {
- auto runtimes = _runtime_map[device_type];
- CHECK_ARGUMENT((size_t)device_id < runtimes.size() && device_id >= 0, "invalid device id");
+ auto &runtimes = _runtime_map[device_type];
+
+ if ((size_t)device_id >= runtimes.size()) {
+ const LlaisysRuntimeAPI *api_ = llaisysGetRuntimeAPI(device_type);
+ int device_count = api_->get_device_count();
+ CHECK_ARGUMENT(device_id >= 0 && device_id < device_count, "invalid device id");
+ runtimes.resize(device_count, nullptr);
+ }
+
if (_current_runtime != nullptr) {
_current_runtime->_deactivate();
}
diff --git a/src/device/nvidia/nvidia_runtime_api.cu b/src/device/nvidia/nvidia_runtime_api.cu
index cab928261..7d29445dc 100644
--- a/src/device/nvidia/nvidia_runtime_api.cu
+++ b/src/device/nvidia/nvidia_runtime_api.cu
@@ -1,56 +1,87 @@
#include "../runtime_api.hpp"
-#include
-#include
+#include <cstdio>
+#include <stdexcept>
+
+#define CUDA_CHECK(call) \
+ do { \
+ cudaError_t err = (call); \
+ if (err != cudaSuccess) { \
+ fprintf(stderr, "[CUDA ERROR] %s at %s:%d\n", \
+ cudaGetErrorString(err), __FILE__, __LINE__); \
+ throw std::runtime_error(cudaGetErrorString(err)); \
+ } \
+ } while (0)
namespace llaisys::device::nvidia {
+static cudaMemcpyKind toCudaMemcpyKind(llaisysMemcpyKind_t kind) {
+ switch (kind) {
+ case LLAISYS_MEMCPY_H2H: return cudaMemcpyHostToHost;
+ case LLAISYS_MEMCPY_H2D: return cudaMemcpyHostToDevice;
+ case LLAISYS_MEMCPY_D2H: return cudaMemcpyDeviceToHost;
+ case LLAISYS_MEMCPY_D2D: return cudaMemcpyDeviceToDevice;
+ default: return cudaMemcpyDefault;
+ }
+}
+
namespace runtime_api {
+
int getDeviceCount() {
- TO_BE_IMPLEMENTED();
+ int count = 0;
+ CUDA_CHECK(cudaGetDeviceCount(&count));
+ return count;
}
-void setDevice(int) {
- TO_BE_IMPLEMENTED();
+void setDevice(int device) {
+ CUDA_CHECK(cudaSetDevice(device));
}
void deviceSynchronize() {
- TO_BE_IMPLEMENTED();
+ CUDA_CHECK(cudaDeviceSynchronize());
}
llaisysStream_t createStream() {
- TO_BE_IMPLEMENTED();
+ cudaStream_t stream;
+ CUDA_CHECK(cudaStreamCreate(&stream));
+    return reinterpret_cast<llaisysStream_t>(stream);
}
void destroyStream(llaisysStream_t stream) {
- TO_BE_IMPLEMENTED();
+    CUDA_CHECK(cudaStreamDestroy(reinterpret_cast<cudaStream_t>(stream)));
}
+
void streamSynchronize(llaisysStream_t stream) {
- TO_BE_IMPLEMENTED();
+    CUDA_CHECK(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)));
}
void *mallocDevice(size_t size) {
- TO_BE_IMPLEMENTED();
+ void *ptr = nullptr;
+ CUDA_CHECK(cudaMalloc(&ptr, size));
+ return ptr;
}
void freeDevice(void *ptr) {
- TO_BE_IMPLEMENTED();
+ CUDA_CHECK(cudaFree(ptr));
}
void *mallocHost(size_t size) {
- TO_BE_IMPLEMENTED();
+ void *ptr = nullptr;
+ CUDA_CHECK(cudaMallocHost(&ptr, size));
+ return ptr;
}
void freeHost(void *ptr) {
- TO_BE_IMPLEMENTED();
+ CUDA_CHECK(cudaFreeHost(ptr));
}
void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) {
- TO_BE_IMPLEMENTED();
+ CUDA_CHECK(cudaMemcpy(dst, src, size, toCudaMemcpyKind(kind)));
}
-void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) {
- TO_BE_IMPLEMENTED();
+void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) {
+    CUDA_CHECK(cudaMemcpyAsync(dst, src, size, toCudaMemcpyKind(kind),
+                               reinterpret_cast<cudaStream_t>(stream)));
}
static const LlaisysRuntimeAPI RUNTIME_API = {
diff --git a/src/device/runtime_api.hpp b/src/device/runtime_api.hpp
index e6b9f80d6..29f22288b 100644
--- a/src/device/runtime_api.hpp
+++ b/src/device/runtime_api.hpp
@@ -1,4 +1,5 @@
#pragma once
+#include "llaisys/build_config.h"
#include "llaisys/runtime.h"
#include "../utils.hpp"
diff --git a/src/llaisys/ops.cc b/src/llaisys/ops.cc
index c99fbc32f..ca86ace9d 100644
--- a/src/llaisys/ops.cc
+++ b/src/llaisys/ops.cc
@@ -11,6 +11,7 @@
#include "../ops/rope/op.hpp"
#include "../ops/self_attention/op.hpp"
#include "../ops/swiglu/op.hpp"
+#include "../ops/sample/op.hpp"
__C {
void llaisysAdd(llaisysTensor_t c, llaisysTensor_t a, llaisysTensor_t b) {
@@ -23,7 +24,7 @@ __C {
llaisys::ops::embedding(out->tensor, index->tensor, weight->tensor);
}
void llaisysLinear(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, llaisysTensor_t bias) {
- llaisys::ops::linear(out->tensor, in->tensor, weight->tensor, bias->tensor);
+ llaisys::ops::linear(out->tensor, in->tensor, weight->tensor, bias ? bias->tensor : nullptr);
}
void llaisysRearrange(llaisysTensor_t out, llaisysTensor_t in) {
llaisys::ops::rearrange(out->tensor, in->tensor);
@@ -40,4 +41,7 @@ __C {
void llaisysSwiGLU(llaisysTensor_t out, llaisysTensor_t gate, llaisysTensor_t up) {
llaisys::ops::swiglu(out->tensor, gate->tensor, up->tensor);
}
+ void llaisysSample(llaisysTensor_t out_idx, llaisysTensor_t logits, float temperature, int top_k, float top_p) {
+ llaisys::ops::sample(out_idx->tensor, logits->tensor, temperature, top_k, top_p);
+ }
}
diff --git a/src/llaisys/qwen2.cc b/src/llaisys/qwen2.cc
new file mode 100644
index 000000000..24803aab9
--- /dev/null
+++ b/src/llaisys/qwen2.cc
@@ -0,0 +1,147 @@
+#include "llaisys/models/qwen2.h"
+#include "llaisys_tensor.hpp"
+#include "../models/qwen2.hpp"
+
+#include <deque>
+
+__C {
+ struct LlaisysQwen2Model {
+ llaisys::models::Qwen2Model *model;
+ LlaisysQwen2Weights c_weights;
+        std::vector<llaisysTensor_t> attn_norm_w_ptrs;
+        std::vector<llaisysTensor_t> attn_q_w_ptrs, attn_q_b_ptrs;
+        std::vector<llaisysTensor_t> attn_k_w_ptrs, attn_k_b_ptrs;
+        std::vector<llaisysTensor_t> attn_v_w_ptrs, attn_v_b_ptrs;
+        std::vector<llaisysTensor_t> attn_o_w_ptrs;
+        std::vector<llaisysTensor_t> mlp_norm_w_ptrs;
+        std::vector<llaisysTensor_t> mlp_gate_w_ptrs, mlp_up_w_ptrs, mlp_down_w_ptrs;
+        // Use deque to avoid pointer invalidation on push_back
+        std::deque<LlaisysTensor> tensor_store;
+ };
+
+ struct LlaisysQwen2Model *llaisysQwen2ModelCreate(
+ const LlaisysQwen2Meta *meta,
+ llaisysDeviceType_t device,
+ int *device_ids,
+ int ndevice) {
+
+ llaisys::models::Qwen2Config config;
+ config.dtype = meta->dtype;
+ config.nlayer = meta->nlayer;
+ config.hs = meta->hs;
+ config.nh = meta->nh;
+ config.nkvh = meta->nkvh;
+ config.dh = meta->dh;
+ config.di = meta->di;
+ config.maxseq = meta->maxseq;
+ config.voc = meta->voc;
+ config.epsilon = meta->epsilon;
+ config.theta = meta->theta;
+ config.end_token = meta->end_token;
+
+ int device_id = (ndevice > 0) ? device_ids[0] : 0;
+
+ auto *w = new LlaisysQwen2Model();
+ w->model = new llaisys::models::Qwen2Model(config, device, device_id);
+
+ auto &weights = w->model->weights();
+ size_t nlayer = config.nlayer;
+ size_t hs = config.hs, nh = config.nh, nkvh = config.nkvh;
+ size_t dh = config.dh, di = config.di, voc = config.voc;
+ auto dtype = config.dtype;
+
+ auto wrap = [&](llaisys::tensor_t t) -> llaisysTensor_t {
+ w->tensor_store.push_back(LlaisysTensor{t});
+ return &w->tensor_store.back();
+ };
+
+ weights.in_embed = llaisys::Tensor::create({voc, hs}, dtype, device, device_id);
+ weights.out_embed = llaisys::Tensor::create({voc, hs}, dtype, device, device_id);
+ weights.out_norm_w = llaisys::Tensor::create({hs}, dtype, device, device_id);
+
+ w->c_weights.in_embed = wrap(weights.in_embed);
+ w->c_weights.out_embed = wrap(weights.out_embed);
+ w->c_weights.out_norm_w = wrap(weights.out_norm_w);
+
+ w->attn_norm_w_ptrs.resize(nlayer);
+ w->attn_q_w_ptrs.resize(nlayer);
+ w->attn_q_b_ptrs.resize(nlayer);
+ w->attn_k_w_ptrs.resize(nlayer);
+ w->attn_k_b_ptrs.resize(nlayer);
+ w->attn_v_w_ptrs.resize(nlayer);
+ w->attn_v_b_ptrs.resize(nlayer);
+ w->attn_o_w_ptrs.resize(nlayer);
+ w->mlp_norm_w_ptrs.resize(nlayer);
+ w->mlp_gate_w_ptrs.resize(nlayer);
+ w->mlp_up_w_ptrs.resize(nlayer);
+ w->mlp_down_w_ptrs.resize(nlayer);
+
+ for (size_t i = 0; i < nlayer; i++) {
+ auto &lw = weights.layers[i];
+ lw.attn_norm_w = llaisys::Tensor::create({hs}, dtype, device, device_id);
+ lw.attn_q_w = llaisys::Tensor::create({nh * dh, hs}, dtype, device, device_id);
+ lw.attn_q_b = llaisys::Tensor::create({nh * dh}, dtype, device, device_id);
+ lw.attn_k_w = llaisys::Tensor::create({nkvh * dh, hs}, dtype, device, device_id);
+ lw.attn_k_b = llaisys::Tensor::create({nkvh * dh}, dtype, device, device_id);
+ lw.attn_v_w = llaisys::Tensor::create({nkvh * dh, hs}, dtype, device, device_id);
+ lw.attn_v_b = llaisys::Tensor::create({nkvh * dh}, dtype, device, device_id);
+ lw.attn_o_w = llaisys::Tensor::create({hs, nh * dh}, dtype, device, device_id);
+ lw.mlp_norm_w = llaisys::Tensor::create({hs}, dtype, device, device_id);
+ lw.mlp_gate_w = llaisys::Tensor::create({di, hs}, dtype, device, device_id);
+ lw.mlp_up_w = llaisys::Tensor::create({di, hs}, dtype, device, device_id);
+ lw.mlp_down_w = llaisys::Tensor::create({hs, di}, dtype, device, device_id);
+
+ w->attn_norm_w_ptrs[i] = wrap(lw.attn_norm_w);
+ w->attn_q_w_ptrs[i] = wrap(lw.attn_q_w);
+ w->attn_q_b_ptrs[i] = wrap(lw.attn_q_b);
+ w->attn_k_w_ptrs[i] = wrap(lw.attn_k_w);
+ w->attn_k_b_ptrs[i] = wrap(lw.attn_k_b);
+ w->attn_v_w_ptrs[i] = wrap(lw.attn_v_w);
+ w->attn_v_b_ptrs[i] = wrap(lw.attn_v_b);
+ w->attn_o_w_ptrs[i] = wrap(lw.attn_o_w);
+ w->mlp_norm_w_ptrs[i] = wrap(lw.mlp_norm_w);
+ w->mlp_gate_w_ptrs[i] = wrap(lw.mlp_gate_w);
+ w->mlp_up_w_ptrs[i] = wrap(lw.mlp_up_w);
+ w->mlp_down_w_ptrs[i] = wrap(lw.mlp_down_w);
+ }
+
+ w->c_weights.attn_norm_w = w->attn_norm_w_ptrs.data();
+ w->c_weights.attn_q_w = w->attn_q_w_ptrs.data();
+ w->c_weights.attn_q_b = w->attn_q_b_ptrs.data();
+ w->c_weights.attn_k_w = w->attn_k_w_ptrs.data();
+ w->c_weights.attn_k_b = w->attn_k_b_ptrs.data();
+ w->c_weights.attn_v_w = w->attn_v_w_ptrs.data();
+ w->c_weights.attn_v_b = w->attn_v_b_ptrs.data();
+ w->c_weights.attn_o_w = w->attn_o_w_ptrs.data();
+ w->c_weights.mlp_norm_w = w->mlp_norm_w_ptrs.data();
+ w->c_weights.mlp_gate_w = w->mlp_gate_w_ptrs.data();
+ w->c_weights.mlp_up_w = w->mlp_up_w_ptrs.data();
+ w->c_weights.mlp_down_w = w->mlp_down_w_ptrs.data();
+
+ return w;
+ }
+
+ void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model *model) {
+ if (model) {
+ delete model->model;
+ delete model;
+ }
+ }
+
+ struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model *model) {
+ return &model->c_weights;
+ }
+
+ int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken) {
+ return model->model->infer(token_ids, ntoken);
+ }
+
+ int64_t llaisysQwen2ModelInferSample(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken,
+ float temperature, int top_k, float top_p) {
+ return model->model->infer_sample(token_ids, ntoken, temperature, top_k, top_p);
+ }
+
+ void llaisysQwen2ModelResetKVCache(struct LlaisysQwen2Model *model) {
+ model->model->reset_kvcache();
+ }
+}
diff --git a/src/models/qwen2.cpp b/src/models/qwen2.cpp
new file mode 100644
index 000000000..f73f1e793
--- /dev/null
+++ b/src/models/qwen2.cpp
@@ -0,0 +1,180 @@
+#include "qwen2.hpp"
+#include "../core/llaisys_core.hpp"
+#include "../utils.hpp"
+
+#include <cmath>
+#include <cstdint>
+#include <vector>
+
+namespace llaisys::models {
+
+Qwen2Model::Qwen2Model(const Qwen2Config &config, llaisysDeviceType_t device_type, int device_id)
+ : _config(config), _device_type(device_type), _device_id(device_id) {
+
+ core::context().setDevice(_device_type, _device_id);
+
+ _weights.layers.resize(config.nlayer);
+
+ _kvcache.resize(config.nlayer);
+ for (size_t i = 0; i < config.nlayer; i++) {
+ _kvcache[i].k = _alloc({config.maxseq, config.nkvh, config.dh});
+ _kvcache[i].v = _alloc({config.maxseq, config.nkvh, config.dh});
+ _kvcache[i].len = 0;
+ }
+}
+
+tensor_t Qwen2Model::_alloc(const std::vector<size_t> &shape) {
+ return Tensor::create(shape, _config.dtype, _device_type, _device_id);
+}
+
+tensor_t Qwen2Model::_alloc(const std::vector<size_t> &shape, llaisysDataType_t dtype) {
+ return Tensor::create(shape, dtype, _device_type, _device_id);
+}
+
+void Qwen2Model::_copy_into(tensor_t dst, size_t dst_offset_elems, tensor_t src) {
+ size_t bytes = src->numel() * src->elementSize();
+ size_t offset_bytes = dst_offset_elems * dst->elementSize();
+ auto &rt = core::context().runtime();
+ rt.api()->memcpy_async(
+ dst->data() + offset_bytes, src->data(), bytes, LLAISYS_MEMCPY_D2D, rt.stream());
+}
+
+void Qwen2Model::_ensure_workspace(size_t seqlen) {
+ if (_ws.seqlen == seqlen) return;
+ _ws.seqlen = seqlen;
+
+ auto &c = _config;
+ _ws.input_ids = _alloc({seqlen}, LLAISYS_DTYPE_I64);
+ _ws.pos_ids = _alloc({seqlen}, LLAISYS_DTYPE_I64);
+ _ws.hidden = _alloc({seqlen, c.hs});
+ _ws.normed = _alloc({seqlen, c.hs});
+ _ws.q_proj = _alloc({seqlen, c.nh * c.dh});
+ _ws.k_proj = _alloc({seqlen, c.nkvh * c.dh});
+ _ws.v_proj = _alloc({seqlen, c.nkvh * c.dh});
+ _ws.attn_out_flat = _alloc({seqlen, c.nh * c.dh});
+ _ws.attn_projected = _alloc({seqlen, c.hs});
+ _ws.gate_buf = _alloc({seqlen, c.di});
+ _ws.up_buf = _alloc({seqlen, c.di});
+ _ws.swiglu_out = _alloc({seqlen, c.di});
+ _ws.mlp_out = _alloc({seqlen, c.hs});
+ _ws.residual = _alloc({seqlen, c.hs});
+ _ws.q_rope = _alloc({seqlen, c.nh, c.dh});
+ _ws.k_rope = _alloc({seqlen, c.nkvh, c.dh});
+ _ws.attn_val = _alloc({seqlen, c.nh, c.dh});
+ _ws.logits = _alloc({1, c.voc});
+ _ws.max_idx = _alloc({1}, LLAISYS_DTYPE_I64);
+ _ws.max_val = _alloc({1});
+ _ws.sampled_idx = _alloc({1}, LLAISYS_DTYPE_I64);
+}
+
+void Qwen2Model::reset_kvcache() {
+ for (auto &kv : _kvcache) {
+ kv.len = 0;
+ }
+}
+
+tensor_t Qwen2Model::forward(const int64_t *token_ids, size_t ntoken) {
+ core::context().setDevice(_device_type, _device_id);
+
+ auto &cfg = _config;
+ size_t seqlen = ntoken;
+ size_t nh = cfg.nh;
+ size_t nkvh = cfg.nkvh;
+ size_t dh = cfg.dh;
+
+ _ensure_workspace(seqlen);
+
+ _ws.input_ids->load(token_ids);
+
+ size_t start_pos = _kvcache[0].len;
+ std::vector<int64_t> pos_data(seqlen);
+ for (size_t i = 0; i < seqlen; i++) {
+ pos_data[i] = static_cast<int64_t>(start_pos + i);
+ }
+ _ws.pos_ids->load(pos_data.data());
+
+ ops::embedding(_ws.hidden, _ws.input_ids, _weights.in_embed);
+
+ for (size_t layer = 0; layer < cfg.nlayer; layer++) {
+ auto &lw = _weights.layers[layer];
+ auto &kv = _kvcache[layer];
+
+ ops::rms_norm(_ws.normed, _ws.hidden, lw.attn_norm_w, cfg.epsilon);
+
+ ops::linear(_ws.q_proj, _ws.normed, lw.attn_q_w, lw.attn_q_b);
+ ops::linear(_ws.k_proj, _ws.normed, lw.attn_k_w, lw.attn_k_b);
+ ops::linear(_ws.v_proj, _ws.normed, lw.attn_v_w, lw.attn_v_b);
+
+ auto q = _ws.q_proj->view({seqlen, nh, dh});
+ auto k_new = _ws.k_proj->view({seqlen, nkvh, dh});
+ auto v_new = _ws.v_proj->view({seqlen, nkvh, dh});
+
+ ops::rope(_ws.q_rope, q, _ws.pos_ids, cfg.theta);
+ ops::rope(_ws.k_rope, k_new, _ws.pos_ids, cfg.theta);
+
+ size_t kv_offset = kv.len * nkvh * dh;
+ _copy_into(kv.k, kv_offset, _ws.k_rope);
+ _copy_into(kv.v, kv_offset, v_new);
+
+ size_t total_len = kv.len + seqlen;
+
+ auto k_full = kv.k->slice(0, 0, total_len);
+ auto v_full = kv.v->slice(0, 0, total_len);
+
+ float scale = 1.0f / std::sqrt(static_cast<float>(dh));
+ ops::self_attention(_ws.attn_val, _ws.q_rope, k_full, v_full, scale);
+
+ auto attn_flat = _ws.attn_val->view({seqlen, nh * dh});
+ ops::linear(_ws.attn_projected, attn_flat, lw.attn_o_w, nullptr);
+
+ ops::add(_ws.residual, _ws.hidden, _ws.attn_projected);
+
+ ops::rms_norm(_ws.normed, _ws.residual, lw.mlp_norm_w, cfg.epsilon);
+
+ ops::linear(_ws.gate_buf, _ws.normed, lw.mlp_gate_w, nullptr);
+ ops::linear(_ws.up_buf, _ws.normed, lw.mlp_up_w, nullptr);
+ ops::swiglu(_ws.swiglu_out, _ws.gate_buf, _ws.up_buf);
+ ops::linear(_ws.mlp_out, _ws.swiglu_out, lw.mlp_down_w, nullptr);
+
+ ops::add(_ws.hidden, _ws.residual, _ws.mlp_out);
+
+ kv.len = total_len;
+ }
+
+ ops::rms_norm(_ws.normed, _ws.hidden, _weights.out_norm_w, cfg.epsilon);
+
+ auto last_hidden = _ws.normed->slice(0, seqlen - 1, seqlen);
+
+ ops::linear(_ws.logits, last_hidden, _weights.out_embed, nullptr);
+
+ return _ws.logits;
+}
+
+int64_t Qwen2Model::infer(const int64_t *token_ids, size_t ntoken) {
+ auto logits = forward(token_ids, ntoken);
+
+ _ensure_workspace(ntoken);
+ ops::argmax(_ws.max_idx, _ws.max_val, logits->view({_config.voc}));
+
+ int64_t result = 0;
+ core::context().runtime().api()->memcpy_sync(
+ &result, _ws.max_idx->data(), sizeof(int64_t), LLAISYS_MEMCPY_D2H);
+
+ return result;
+}
+
+int64_t Qwen2Model::infer_sample(const int64_t *token_ids, size_t ntoken,
+ float temperature, int top_k, float top_p) {
+ auto logits = forward(token_ids, ntoken);
+
+ _ensure_workspace(ntoken);
+ ops::sample(_ws.sampled_idx, logits->view({_config.voc}), temperature, top_k, top_p);
+
+ int64_t result = 0;
+ core::context().runtime().api()->memcpy_sync(
+ &result, _ws.sampled_idx->data(), sizeof(int64_t), LLAISYS_MEMCPY_D2H);
+
+ return result;
+}
+
+} // namespace llaisys::models
diff --git a/src/models/qwen2.hpp b/src/models/qwen2.hpp
new file mode 100644
index 000000000..99f64274b
--- /dev/null
+++ b/src/models/qwen2.hpp
@@ -0,0 +1,90 @@
+#pragma once
+
+#include "../tensor/tensor.hpp"
+#include "../ops/add/op.hpp"
+#include "../ops/argmax/op.hpp"
+#include "../ops/embedding/op.hpp"
+#include "../ops/linear/op.hpp"
+#include "../ops/rms_norm/op.hpp"
+#include "../ops/rope/op.hpp"
+#include "../ops/self_attention/op.hpp"
+#include "../ops/swiglu/op.hpp"
+#include "../ops/sample/op.hpp"
+
+#include <vector>
+
+namespace llaisys::models {
+
+struct Qwen2Config {
+ llaisysDataType_t dtype;
+ size_t nlayer, hs, nh, nkvh, dh, di, maxseq, voc;
+ float epsilon, theta;
+ int64_t end_token;
+};
+
+struct Qwen2LayerWeights {
+ tensor_t attn_norm_w;
+ tensor_t attn_q_w, attn_q_b;
+ tensor_t attn_k_w, attn_k_b;
+ tensor_t attn_v_w, attn_v_b;
+ tensor_t attn_o_w;
+ tensor_t mlp_norm_w;
+ tensor_t mlp_gate_w, mlp_up_w, mlp_down_w;
+};
+
+struct Qwen2Weights {
+ tensor_t in_embed;
+ tensor_t out_embed;
+ tensor_t out_norm_w;
+ std::vector<Qwen2LayerWeights> layers;
+};
+
+struct KVCache {
+ tensor_t k; // [maxseq, nkvh, dh]
+ tensor_t v; // [maxseq, nkvh, dh]
+ size_t len;
+};
+
+struct Qwen2Workspace {
+ size_t seqlen = 0;
+ tensor_t input_ids, pos_ids;
+ tensor_t hidden, normed;
+ tensor_t q_proj, k_proj, v_proj;
+ tensor_t attn_out_flat, attn_projected;
+ tensor_t gate_buf, up_buf, swiglu_out, mlp_out;
+ tensor_t residual;
+ tensor_t q_rope, k_rope, attn_val;
+ tensor_t logits;
+ tensor_t max_idx, max_val, sampled_idx;
+};
+
+class Qwen2Model {
+private:
+ Qwen2Config _config;
+ Qwen2Weights _weights;
+ std::vector<KVCache> _kvcache;
+ llaisysDeviceType_t _device_type;
+ int _device_id;
+ Qwen2Workspace _ws;
+
+ tensor_t _alloc(const std::vector<size_t> &shape);
+ tensor_t _alloc(const std::vector<size_t> &shape, llaisysDataType_t dtype);
+
+ void _copy_into(tensor_t dst, size_t dst_offset_elems, tensor_t src);
+ void _ensure_workspace(size_t seqlen);
+
+ tensor_t forward(const int64_t *token_ids, size_t ntoken);
+
+public:
+ Qwen2Model(const Qwen2Config &config, llaisysDeviceType_t device_type, int device_id);
+ ~Qwen2Model() = default;
+
+ Qwen2Weights &weights() { return _weights; }
+
+ int64_t infer(const int64_t *token_ids, size_t ntoken);
+ int64_t infer_sample(const int64_t *token_ids, size_t ntoken,
+ float temperature, int top_k, float top_p);
+ void reset_kvcache();
+};
+
+} // namespace llaisys::models
diff --git a/src/ops/add/cpu/add_cpu.cpp b/src/ops/add/cpu/add_cpu.cpp
index 47f6a3d49..3da1009e0 100644
--- a/src/ops/add/cpu/add_cpu.cpp
+++ b/src/ops/add/cpu/add_cpu.cpp
@@ -1,11 +1,34 @@
+#ifdef __AVX2__
+#include <immintrin.h>
+#endif
+
#include "add_cpu.hpp"
#include "../../../utils.hpp"
#include <type_traits>
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
template <typename T>
void add_(T *c, const T *a, const T *b, size_t numel) {
+#ifdef __AVX2__
+ if constexpr (std::is_same_v<T, float>) {
+ #pragma omp parallel for schedule(static)
+ for (size_t i = 0; i < numel - (numel % 8); i += 8) {
+ __m256 va = _mm256_loadu_ps(a + i);
+ __m256 vb = _mm256_loadu_ps(b + i);
+ _mm256_storeu_ps(c + i, _mm256_add_ps(va, vb));
+ }
+ for (size_t i = numel - (numel % 8); i < numel; i++) {
+ c[i] = a[i] + b[i];
+ }
+ return;
+ }
+#endif
+ #pragma omp parallel for schedule(static)
for (size_t i = 0; i < numel; i++) {
if constexpr (std::is_same_v<T, llaisys::bf16_t> || std::is_same_v<T, llaisys::fp16_t>) {
c[i] = llaisys::utils::cast<T>(llaisys::utils::cast<float>(a[i]) + llaisys::utils::cast<float>(b[i]));
diff --git a/src/ops/add/cuda/add_cuda.cu b/src/ops/add/cuda/add_cuda.cu
new file mode 100644
index 000000000..ab9263483
--- /dev/null
+++ b/src/ops/add/cuda/add_cuda.cu
@@ -0,0 +1,20 @@
+#include "add_cuda.cuh"
+#include "../../cuda_utils.cuh"
+
+__global__ void add_kernel(void *c, const void *a, const void *b,
+ llaisysDataType_t dtype, size_t numel) {
+ size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx >= numel) return;
+
+ float va = load_as_f32(a, idx, dtype);
+ float vb = load_as_f32(b, idx, dtype);
+ store_from_f32(c, idx, va + vb, dtype);
+}
+
+namespace llaisys::ops::cuda {
+void add(std::byte *c, const std::byte *a, const std::byte *b,
+ llaisysDataType_t type, size_t numel) {
+ add_kernel<<<cuda_grid_size(numel), CUDA_BLOCK_SIZE>>>(c, a, b, type, numel);
+ CUDA_KERNEL_CHECK();
+}
+} // namespace llaisys::ops::cuda
diff --git a/src/ops/add/cuda/add_cuda.cuh b/src/ops/add/cuda/add_cuda.cuh
new file mode 100644
index 000000000..208261877
--- /dev/null
+++ b/src/ops/add/cuda/add_cuda.cuh
@@ -0,0 +1,7 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cuda {
+void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t size);
+}
diff --git a/src/ops/add/op.cpp b/src/ops/add/op.cpp
index a057330d7..8954eb14c 100644
--- a/src/ops/add/op.cpp
+++ b/src/ops/add/op.cpp
@@ -4,16 +4,17 @@
#include "../../utils.hpp"
#include "cpu/add_cpu.hpp"
+#ifdef ENABLE_NVIDIA_API
+#include "cuda/add_cuda.cuh"
+#endif
namespace llaisys::ops {
void add(tensor_t c, tensor_t a, tensor_t b) {
CHECK_SAME_DEVICE(c, a, b);
- // Only support contiguous inputs with same shape for now.
CHECK_SAME_SHAPE(c->shape(), a->shape(), b->shape());
CHECK_SAME_DTYPE(c->dtype(), a->dtype(), b->dtype());
ASSERT(c->isContiguous() && a->isContiguous() && b->isContiguous(), "Add: all tensors must be contiguous.");
- // always support cpu calculation
if (c->deviceType() == LLAISYS_DEVICE_CPU) {
return cpu::add(c->data(), a->data(), b->data(), c->dtype(), c->numel());
}
@@ -25,8 +26,7 @@ void add(tensor_t c, tensor_t a, tensor_t b) {
return cpu::add(c->data(), a->data(), b->data(), c->dtype(), c->numel());
#ifdef ENABLE_NVIDIA_API
case LLAISYS_DEVICE_NVIDIA:
- TO_BE_IMPLEMENTED();
- return;
+ return cuda::add(c->data(), a->data(), b->data(), c->dtype(), c->numel());
#endif
default:
EXCEPTION_UNSUPPORTED_DEVICE;
diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp
new file mode 100644
index 000000000..0ad8fe2b9
--- /dev/null
+++ b/src/ops/argmax/cpu/argmax_cpu.cpp
@@ -0,0 +1,86 @@
+#ifdef __AVX2__
+#include
+#endif
+
+#include "argmax_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cstdint>
+#include <limits>
+
+template <typename T>
+void argmax_(int64_t *max_idx, T *max_val, const T *vals, size_t numel) {
+ float best = -std::numeric_limits::infinity();
+ int64_t best_idx = 0;
+
+#ifdef __AVX2__
+ if constexpr (std::is_same_v<T, float>) {
+ if (numel >= 8) {
+ __m256 vbest = _mm256_set1_ps(-std::numeric_limits::infinity());
+ __m256i vidx = _mm256_setzero_si256();
+ __m256i vcur = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+ __m256i vinc = _mm256_set1_epi32(8);
+
+ size_t i = 0;
+ for (; i + 8 <= numel; i += 8) {
+ __m256 vv = _mm256_loadu_ps(vals + i);
+ __m256 mask = _mm256_cmp_ps(vv, vbest, _CMP_GT_OQ);
+ vbest = _mm256_blendv_ps(vbest, vv, mask);
+ vidx = _mm256_castps_si256(_mm256_blendv_ps(
+ _mm256_castsi256_ps(vidx), _mm256_castsi256_ps(vcur), mask));
+ vcur = _mm256_add_epi32(vcur, vinc);
+ }
+
+ float bests[8];
+ int32_t idxs[8];
+ _mm256_storeu_ps(bests, vbest);
+ _mm256_storeu_si256(reinterpret_cast<__m256i *>(idxs), vidx);
+
+ for (int j = 0; j < 8; j++) {
+ if (bests[j] > best) {
+ best = bests[j];
+ best_idx = idxs[j];
+ }
+ }
+
+ for (; i < numel; i++) {
+ if (vals[i] > best) {
+ best = vals[i];
+ best_idx = static_cast<int64_t>(i);
+ }
+ }
+
+ *max_idx = best_idx;
+ *max_val = static_cast<T>(best);
+ return;
+ }
+ }
+#endif
+
+ for (size_t i = 0; i < numel; i++) {
+ float v = llaisys::utils::cast<float>(vals[i]);
+ if (v > best) {
+ best = v;
+ best_idx = static_cast<int64_t>(i);
+ }
+ }
+ *max_idx = best_idx;
+ *max_val = llaisys::utils::cast<T>(best);
+}
+
+namespace llaisys::ops::cpu {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel) {
+ auto *idx_ptr = reinterpret_cast<int64_t *>(max_idx);
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ return argmax_(idx_ptr, reinterpret_cast<float *>(max_val), reinterpret_cast<const float *>(vals), numel);
+ case LLAISYS_DTYPE_BF16:
+ return argmax_(idx_ptr, reinterpret_cast<llaisys::bf16_t *>(max_val), reinterpret_cast<const llaisys::bf16_t *>(vals), numel);
+ case LLAISYS_DTYPE_F16:
+ return argmax_(idx_ptr, reinterpret_cast<llaisys::fp16_t *>(max_val), reinterpret_cast<const llaisys::fp16_t *>(vals), numel);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp
new file mode 100644
index 000000000..26ae3ef03
--- /dev/null
+++ b/src/ops/argmax/cpu/argmax_cpu.hpp
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel);
+}
diff --git a/src/ops/argmax/cuda/argmax_cuda.cu b/src/ops/argmax/cuda/argmax_cuda.cu
new file mode 100644
index 000000000..0ced106f6
--- /dev/null
+++ b/src/ops/argmax/cuda/argmax_cuda.cu
@@ -0,0 +1,90 @@
+#include "argmax_cuda.cuh"
+#include "../../cuda_utils.cuh"
+
+#include <cfloat>
+
+// Parallel reduction for argmax
+__global__ void argmax_kernel(int64_t *max_idx_out, void *max_val_out,
+ const void *vals, llaisysDataType_t dtype, size_t numel) {
+ extern __shared__ char shared_mem[];
+ float *svals = reinterpret_cast<float *>(shared_mem);
+ int *sidxs = reinterpret_cast<int *>(shared_mem + blockDim.x * sizeof(float));
+
+ int tid = threadIdx.x;
+ size_t idx = blockIdx.x * blockDim.x + tid;
+
+ svals[tid] = -FLT_MAX;
+ sidxs[tid] = 0;
+
+ if (idx < numel) {
+ svals[tid] = load_as_f32(vals, idx, dtype);
+ sidxs[tid] = idx;
+ }
+ __syncthreads();
+
+ for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+ if (tid < s && svals[tid + s] > svals[tid]) {
+ svals[tid] = svals[tid + s];
+ sidxs[tid] = sidxs[tid + s];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0) {
+ // Atomic compare: use atomicCAS on a global flag
+ // For single-block case, just write directly
+ // For multi-block, we need a second pass. Simplify: use single block for vocab-sized vectors.
+ max_idx_out[blockIdx.x] = sidxs[0];
+ store_from_f32(max_val_out, blockIdx.x, svals[0], dtype);
+ }
+}
+
+// Second pass: reduce across blocks
+__global__ void argmax_reduce_kernel(int64_t *final_idx, void *final_val,
+ const int64_t *block_idx, const void *block_val,
+ llaisysDataType_t dtype, int nblocks) {
+ float best = -FLT_MAX;
+ int64_t best_idx = 0;
+ for (int i = 0; i < nblocks; i++) {
+ float v = load_as_f32(block_val, i, dtype);
+ if (v > best) {
+ best = v;
+ best_idx = block_idx[i];
+ }
+ }
+ *final_idx = best_idx;
+ store_from_f32(final_val, 0, best, dtype);
+}
+
+namespace llaisys::ops::cuda {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals,
+ llaisysDataType_t type, size_t numel) {
+ int block_size = 1024;
+ int nblocks = cuda_grid_size(numel, block_size);
+ size_t shared_size = block_size * (sizeof(float) + sizeof(int));
+
+ if (nblocks == 1) {
+ argmax_kernel<<<1, block_size, shared_size>>>(
+ reinterpret_cast<int64_t *>(max_idx), max_val, vals, type, numel);
+ CUDA_KERNEL_CHECK();
+ } else {
+ int64_t *block_idx;
+ std::byte *block_val;
+ size_t val_size = cuda_dsize(type);
+ cudaMalloc(&block_idx, nblocks * sizeof(int64_t));
+ cudaMalloc(&block_val, nblocks * val_size);
+
+ argmax_kernel<<>>(
+ block_idx, block_val, vals, type, numel);
+ CUDA_KERNEL_CHECK();
+
+ argmax_reduce_kernel<<<1, 1>>>(
+ reinterpret_cast<int64_t *>(max_idx), max_val,
+ block_idx, block_val, type, nblocks);
+ CUDA_KERNEL_CHECK();
+
+ cudaFree(block_idx);
+ cudaFree(block_val);
+ }
+}
+} // namespace llaisys::ops::cuda
diff --git a/src/ops/argmax/cuda/argmax_cuda.cuh b/src/ops/argmax/cuda/argmax_cuda.cuh
new file mode 100644
index 000000000..179eded8b
--- /dev/null
+++ b/src/ops/argmax/cuda/argmax_cuda.cuh
@@ -0,0 +1,7 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cuda {
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel);
+}
diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp
index 6dc37d426..89c9f3271 100644
--- a/src/ops/argmax/op.cpp
+++ b/src/ops/argmax/op.cpp
@@ -1,7 +1,32 @@
#include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/argmax_cpu.hpp"
+#ifdef ENABLE_NVIDIA_API
+#include "cuda/argmax_cuda.cuh"
+#endif
+
namespace llaisys::ops {
void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) {
- TO_BE_IMPLEMENTED();
+ ASSERT(vals->isContiguous(), "Argmax: vals must be contiguous.");
+
+ if (vals->deviceType() == LLAISYS_DEVICE_CPU) {
+ return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+ }
+
+ llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId());
+
+ switch (vals->deviceType()) {
+ case LLAISYS_DEVICE_CPU:
+ return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return cuda::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
}
} // namespace llaisys::ops
diff --git a/src/ops/cuda_utils.cuh b/src/ops/cuda_utils.cuh
new file mode 100644
index 000000000..59966b323
--- /dev/null
+++ b/src/ops/cuda_utils.cuh
@@ -0,0 +1,91 @@
+#pragma once
+
+#include <cuda_runtime.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cstdio>
+#include <cstring>
+#include <stdexcept>
+
+#include "llaisys.h"
+
+#define CUDA_KERNEL_CHECK() \
+ do { \
+ cudaError_t err = cudaGetLastError(); \
+ if (err != cudaSuccess) { \
+ fprintf(stderr, "[CUDA KERNEL ERROR] %s at %s:%d\n", \
+ cudaGetErrorString(err), __FILE__, __LINE__); \
+ throw std::runtime_error(cudaGetErrorString(err)); \
+ } \
+ } while (0)
+
+__device__ __forceinline__ float bf16_to_f32(uint16_t v) {
+ __nv_bfloat16 bf;
+ memcpy(&bf, &v, sizeof(uint16_t));
+ return __bfloat162float(bf);
+}
+
+__device__ __forceinline__ uint16_t f32_to_bf16(float v) {
+ __nv_bfloat16 bf = __float2bfloat16(v);
+ uint16_t r;
+ memcpy(&r, &bf, sizeof(uint16_t));
+ return r;
+}
+
+__device__ __forceinline__ float fp16_to_f32(uint16_t v) {
+ __half h;
+ memcpy(&h, &v, sizeof(uint16_t));
+ return __half2float(h);
+}
+
+__device__ __forceinline__ uint16_t f32_to_fp16(float v) {
+ __half h = __float2half(v);
+ uint16_t r;
+ memcpy(&r, &h, sizeof(uint16_t));
+ return r;
+}
+
+__device__ __forceinline__ float load_as_f32(const void *ptr, size_t idx, llaisysDataType_t dtype) {
+ switch (dtype) {
+ case LLAISYS_DTYPE_F32:
+ return reinterpret_cast<const float *>(ptr)[idx];
+ case LLAISYS_DTYPE_BF16:
+ return bf16_to_f32(reinterpret_cast<const uint16_t *>(ptr)[idx]);
+ case LLAISYS_DTYPE_F16:
+ return fp16_to_f32(reinterpret_cast<const uint16_t *>(ptr)[idx]);
+ default:
+ return 0.0f;
+ }
+}
+
+__device__ __forceinline__ void store_from_f32(void *ptr, size_t idx, float val, llaisysDataType_t dtype) {
+ switch (dtype) {
+ case LLAISYS_DTYPE_F32:
+ reinterpret_cast<float *>(ptr)[idx] = val;
+ break;
+ case LLAISYS_DTYPE_BF16:
+ reinterpret_cast<uint16_t *>(ptr)[idx] = f32_to_bf16(val);
+ break;
+ case LLAISYS_DTYPE_F16:
+ reinterpret_cast<uint16_t *>(ptr)[idx] = f32_to_fp16(val);
+ break;
+ default:
+ break;
+ }
+}
+
+inline size_t cuda_dsize(llaisysDataType_t dtype) {
+ switch (dtype) {
+ case LLAISYS_DTYPE_F32: return 4;
+ case LLAISYS_DTYPE_BF16: return 2;
+ case LLAISYS_DTYPE_F16: return 2;
+ case LLAISYS_DTYPE_I64: return 8;
+ default: return 0;
+ }
+}
+
+constexpr int CUDA_BLOCK_SIZE = 256;
+
+inline int cuda_grid_size(size_t n, int block_size = CUDA_BLOCK_SIZE) {
+ return static_cast<int>((n + block_size - 1) / block_size);
+}
diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp
new file mode 100644
index 000000000..db02da2d9
--- /dev/null
+++ b/src/ops/embedding/cpu/embedding_cpu.cpp
@@ -0,0 +1,24 @@
+#include "embedding_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cstring>
+
+#ifdef _OPENMP
+#include
+#endif
+
+namespace llaisys::ops::cpu {
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight,
+ llaisysDataType_t dtype, size_t n_idx, size_t embd_dim) {
+ auto *idx = reinterpret_cast<const int64_t *>(index);
+ size_t esize = llaisys::utils::dsize(dtype);
+ size_t row_bytes = embd_dim * esize;
+
+ #pragma omp parallel for schedule(static)
+ for (size_t i = 0; i < n_idx; i++) {
+ int64_t row = idx[i];
+ std::memcpy(out + i * row_bytes, weight + row * row_bytes, row_bytes);
+ }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp
new file mode 100644
index 000000000..933784ce4
--- /dev/null
+++ b/src/ops/embedding/cpu/embedding_cpu.hpp
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight,
+ llaisysDataType_t dtype, size_t n_idx, size_t embd_dim);
+}
diff --git a/src/ops/embedding/cuda/embedding_cuda.cu b/src/ops/embedding/cuda/embedding_cuda.cu
new file mode 100644
index 000000000..259afc92f
--- /dev/null
+++ b/src/ops/embedding/cuda/embedding_cuda.cu
@@ -0,0 +1,33 @@
+#include "embedding_cuda.cuh"
+#include "../../cuda_utils.cuh"
+
+__global__ void embedding_kernel(void *out, const int64_t *index, const void *weight,
+ size_t esize, size_t n_idx, size_t embd_dim) {
+ size_t i = blockIdx.x;
+ size_t j = threadIdx.x + blockIdx.y * blockDim.x;
+ if (i >= n_idx || j >= embd_dim) return;
+
+ int64_t row = index[i];
+ size_t src_off = row * embd_dim * esize + j * esize;
+ size_t dst_off = i * embd_dim * esize + j * esize;
+
+ const char *src = reinterpret_cast<const char *>(weight) + src_off;
+ char *dst = reinterpret_cast<char *>(out) + dst_off;
+
+ for (size_t b = 0; b < esize; b++) {
+ dst[b] = src[b];
+ }
+}
+
+namespace llaisys::ops::cuda {
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight,
+ llaisysDataType_t dtype, size_t n_idx, size_t embd_dim) {
+ size_t esize = cuda_dsize(dtype);
+ int threads_per_block = 256;
+ dim3 grid(n_idx, (embd_dim + threads_per_block - 1) / threads_per_block);
+ dim3 block(threads_per_block);
+ embedding_kernel<<<grid, block>>>(out, reinterpret_cast<const int64_t *>(index),
+ weight, esize, n_idx, embd_dim);
+ CUDA_KERNEL_CHECK();
+}
+} // namespace llaisys::ops::cuda
diff --git a/src/ops/embedding/cuda/embedding_cuda.cuh b/src/ops/embedding/cuda/embedding_cuda.cuh
new file mode 100644
index 000000000..8ced9b25b
--- /dev/null
+++ b/src/ops/embedding/cuda/embedding_cuda.cuh
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cuda {
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight,
+ llaisysDataType_t dtype, size_t n_idx, size_t embd_dim);
+}
diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp
index 84b9a5d06..d20075e0a 100644
--- a/src/ops/embedding/op.cpp
+++ b/src/ops/embedding/op.cpp
@@ -1,7 +1,38 @@
#include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/embedding_cpu.hpp"
+#ifdef ENABLE_NVIDIA_API
+#include "cuda/embedding_cuda.cuh"
+#endif
+
namespace llaisys::ops {
void embedding(tensor_t out, tensor_t index, tensor_t weight) {
- TO_BE_IMPLEMENTED();
+ ASSERT(index->dtype() == LLAISYS_DTYPE_I64, "Embedding: index must be int64.");
+ ASSERT(weight->ndim() == 2, "Embedding: weight must be 2D.");
+ ASSERT(out->ndim() == 2, "Embedding: out must be 2D.");
+ ASSERT(out->isContiguous() && weight->isContiguous(), "Embedding: tensors must be contiguous.");
+
+ size_t n_idx = index->numel();
+ size_t embd_dim = weight->shape()[1];
+
+ if (out->deviceType() == LLAISYS_DEVICE_CPU) {
+ return cpu::embedding(out->data(), index->data(), weight->data(), weight->dtype(), n_idx, embd_dim);
+ }
+
+ llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+
+ switch (out->deviceType()) {
+ case LLAISYS_DEVICE_CPU:
+ return cpu::embedding(out->data(), index->data(), weight->data(), weight->dtype(), n_idx, embd_dim);
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return cuda::embedding(out->data(), index->data(), weight->data(), weight->dtype(), n_idx, embd_dim);
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
}
} // namespace llaisys::ops
diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp
new file mode 100644
index 000000000..1e30e8903
--- /dev/null
+++ b/src/ops/linear/cpu/linear_cpu.cpp
@@ -0,0 +1,175 @@
+#ifdef __AVX2__
+#include <immintrin.h>
+#endif
+
+#ifdef USE_OPENBLAS
+#include <cblas.h>
+#endif
+
+#include "linear_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cstddef>
+#include <cstring>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#ifdef USE_OPENBLAS
+
+static void linear_f32_blas(float *out, const float *in, const float *weight,
+ const float *bias, size_t M, size_t N, size_t K, bool has_bias) {
+ // out[M,N] = in[M,K] * weight[N,K]^T + bias[N]
+ if (has_bias) {
+ for (size_t m = 0; m < M; m++) {
+ std::memcpy(out + m * N, bias, N * sizeof(float));
+ }
+ scipy_cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+ (blasint)M, (blasint)N, (blasint)K,
+ 1.0f, in, (blasint)K, weight, (blasint)K,
+ 1.0f, out, (blasint)N);
+ } else {
+ scipy_cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+ (blasint)M, (blasint)N, (blasint)K,
+ 1.0f, in, (blasint)K, weight, (blasint)K,
+ 0.0f, out, (blasint)N);
+ }
+}
+
+#else // !USE_OPENBLAS
+
+#ifdef __AVX2__
+
+static void linear_f32_avx2(float *out, const float *in, const float *weight,
+ const float *bias, size_t M, size_t N, size_t K, bool has_bias) {
+ #pragma omp parallel for schedule(dynamic)
+ for (size_t m = 0; m < M; m++) {
+ const float *a_row = in + m * K;
+ float *c_row = out + m * N;
+
+ for (size_t n = 0; n < N; n++) {
+ const float *b_row = weight + n * K;
+ __m256 vsum = _mm256_setzero_ps();
+ size_t k = 0;
+
+ for (; k + 8 <= K; k += 8) {
+ __m256 va = _mm256_loadu_ps(a_row + k);
+ __m256 vb = _mm256_loadu_ps(b_row + k);
+ vsum = _mm256_fmadd_ps(va, vb, vsum);
+ }
+
+ float tmp[8];
+ _mm256_storeu_ps(tmp, vsum);
+ float sum = tmp[0] + tmp[1] + tmp[2] + tmp[3] +
+ tmp[4] + tmp[5] + tmp[6] + tmp[7];
+
+ for (; k < K; k++) {
+ sum += a_row[k] * b_row[k];
+ }
+
+ if (has_bias) sum += bias[n];
+ c_row[n] = sum;
+ }
+ }
+}
+
+#endif // __AVX2__
+
+#endif // USE_OPENBLAS
+
+template <typename T>
+void linear_generic(T *out, const T *in, const T *weight, const T *bias,
+ size_t M, size_t N, size_t K, bool has_bias) {
+ // Convert to F32, compute, convert back
+ std::vector<float> f_in(M * K), f_weight(N * K), f_out(M * N);
+ std::vector<float> f_bias;
+ if (has_bias) f_bias.resize(N);
+
+ #pragma omp parallel for schedule(static)
+ for (size_t i = 0; i < M * K; i++)
+ f_in[i] = llaisys::utils::cast<float>(in[i]);
+
+ #pragma omp parallel for schedule(static)
+ for (size_t i = 0; i < N * K; i++)
+ f_weight[i] = llaisys::utils::cast<float>(weight[i]);
+
+ if (has_bias) {
+ for (size_t i = 0; i < N; i++)
+ f_bias[i] = llaisys::utils::cast<float>(bias[i]);
+ }
+
+#ifdef USE_OPENBLAS
+ linear_f32_blas(f_out.data(), f_in.data(), f_weight.data(),
+ has_bias ? f_bias.data() : nullptr, M, N, K, has_bias);
+#elif defined(__AVX2__)
+ linear_f32_avx2(f_out.data(), f_in.data(), f_weight.data(),
+ has_bias ? f_bias.data() : nullptr, M, N, K, has_bias);
+#else
+ // Fallback naive
+ for (size_t m = 0; m < M; m++) {
+ for (size_t n = 0; n < N; n++) {
+ float sum = 0.0f;
+ for (size_t k = 0; k < K; k++)
+ sum += f_in[m * K + k] * f_weight[n * K + k];
+ if (has_bias) sum += f_bias[n];
+ f_out[m * N + n] = sum;
+ }
+ }
+#endif
+
+ #pragma omp parallel for schedule(static)
+ for (size_t i = 0; i < M * N; i++)
+ out[i] = llaisys::utils::cast<T>(f_out[i]);
+}
+
+static void linear_f32(float *out, const float *in, const float *weight,
+ const float *bias, size_t M, size_t N, size_t K, bool has_bias) {
+#ifdef USE_OPENBLAS
+ linear_f32_blas(out, in, weight, bias, M, N, K, has_bias);
+#elif defined(__AVX2__)
+ linear_f32_avx2(out, in, weight, bias, M, N, K, has_bias);
+#else
+ // Fallback: naive with OpenMP
+ #pragma omp parallel for schedule(dynamic)
+ for (size_t m = 0; m < M; m++) {
+ for (size_t n = 0; n < N; n++) {
+ float sum = 0.0f;
+ for (size_t k = 0; k < K; k++)
+ sum += in[m * K + k] * weight[n * K + k];
+ if (has_bias) sum += bias[n];
+ out[m * N + n] = sum;
+ }
+ }
+#endif
+}
+
+namespace llaisys::ops::cpu {
+void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias,
+ llaisysDataType_t dtype, size_t M, size_t N, size_t K, bool has_bias) {
+ switch (dtype) {
+ case LLAISYS_DTYPE_F32:
+ return linear_f32(reinterpret_cast<float *>(out),
+ reinterpret_cast<const float *>(in),
+ reinterpret_cast<const float *>(weight),
+ has_bias ? reinterpret_cast<const float *>(bias) : nullptr,
+ M, N, K, has_bias);
+ case LLAISYS_DTYPE_BF16:
+ return linear_generic(reinterpret_cast<llaisys::bf16_t *>(out),
+ reinterpret_cast<const llaisys::bf16_t *>(in),
+ reinterpret_cast<const llaisys::bf16_t *>(weight),
+ has_bias ? reinterpret_cast<const llaisys::bf16_t *>(bias) : nullptr,
+ M, N, K, has_bias);
+ case LLAISYS_DTYPE_F16:
+ return linear_generic(reinterpret_cast<llaisys::fp16_t *>(out),
+ reinterpret_cast<const llaisys::fp16_t *>(in),
+ reinterpret_cast<const llaisys::fp16_t *>(weight),
+ has_bias ? reinterpret_cast<const llaisys::fp16_t *>(bias) : nullptr,
+ M, N, K, has_bias);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+ }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp
new file mode 100644
index 000000000..19ddec8a2
--- /dev/null
+++ b/src/ops/linear/cpu/linear_cpu.hpp
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias,
+ llaisysDataType_t dtype, size_t M, size_t N, size_t K, bool has_bias);
+}
diff --git a/src/ops/linear/cuda/linear_cuda.cu b/src/ops/linear/cuda/linear_cuda.cu
new file mode 100644
index 000000000..a35f5127f
--- /dev/null
+++ b/src/ops/linear/cuda/linear_cuda.cu
@@ -0,0 +1,102 @@
+#include "linear_cuda.cuh"
+#include "../../cuda_utils.cuh"
+
+#include <cublas_v2.h>
+#include <cstdio>
+
+static cublasHandle_t get_cublas_handle() {
+ static cublasHandle_t handle = nullptr;
+ if (!handle) {
+ cublasStatus_t st = cublasCreate(&handle);
+ if (st != CUBLAS_STATUS_SUCCESS) {
+ fprintf(stderr, "[cuBLAS] cublasCreate failed: %d\n", (int)st);
+ }
+ }
+ return handle;
+}
+
+__global__ void add_bias_kernel(void *out, const void *bias,
+ llaisysDataType_t dtype, size_t M, size_t N) {
+ size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx >= M * N) return;
+ size_t n = idx % N;
+ float val = load_as_f32(out, idx, dtype);
+ float b = load_as_f32(bias, n, dtype);
+ store_from_f32(out, idx, val + b, dtype);
+}
+
+__global__ void convert_to_f32_kernel(float *out, const void *in,
+ llaisysDataType_t dtype, size_t n) {
+ size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx >= n) return;
+ out[idx] = load_as_f32(in, idx, dtype);
+}
+
+__global__ void convert_from_f32_kernel(void *out, const float *in,
+ llaisysDataType_t dtype, size_t n) {
+ size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx >= n) return;
+ store_from_f32(out, idx, in[idx], dtype);
+}
+
+static cudaDataType_t to_cuda_dtype(llaisysDataType_t dtype) {
+ switch (dtype) {
+ case LLAISYS_DTYPE_BF16: return CUDA_R_16BF;
+ case LLAISYS_DTYPE_F16: return CUDA_R_16F;
+ default: return CUDA_R_32F;
+ }
+}
+
+namespace llaisys::ops::cuda {
+void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias,
+ llaisysDataType_t dtype, size_t M, size_t N, size_t K, bool has_bias) {
+ // out[M,N] = in[M,K] * weight[N,K]^T
+ // cuBLAS column-major: C(N,M) = A^T(N,K) * B(K,M)
+ cublasHandle_t handle = get_cublas_handle();
+ float alpha = 1.0f, beta = 0.0f;
+
+ if (dtype == LLAISYS_DTYPE_F16) {
+ // FP16: cublasGemmEx natively supported on all recent GPUs
+ cublasStatus_t st = cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+ (int)N, (int)M, (int)K,
+ &alpha,
+ weight, CUDA_R_16F, (int)K,
+ in, CUDA_R_16F, (int)K,
+ &beta,
+ out, CUDA_R_16F, (int)N,
+ CUBLAS_COMPUTE_32F,
+ CUBLAS_GEMM_DEFAULT);
+ if (st != CUBLAS_STATUS_SUCCESS) {
+ fprintf(stderr, "[cuBLAS] GemmEx FP16 failed: %d\n", (int)st);
+ }
+ } else if (dtype == LLAISYS_DTYPE_F32) {
+ cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+ (int)N, (int)M, (int)K,
+ &alpha,
+                    reinterpret_cast<const float *>(weight), (int)K,
+                    reinterpret_cast<const float *>(in), (int)K,
+                    &beta,
+                    reinterpret_cast<float *>(out), (int)N);
+ } else {
+ // BF16: use cublasGemmEx with native BF16 support (SM 80+, Ampere tensor cores)
+ cudaDataType_t cuda_dt = to_cuda_dtype(dtype);
+ cublasStatus_t st = cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
+ (int)N, (int)M, (int)K,
+ &alpha,
+ weight, cuda_dt, (int)K,
+ in, cuda_dt, (int)K,
+ &beta,
+ out, cuda_dt, (int)N,
+ CUBLAS_COMPUTE_32F,
+ CUBLAS_GEMM_DEFAULT);
+ if (st != CUBLAS_STATUS_SUCCESS) {
+ fprintf(stderr, "[cuBLAS] GemmEx BF16 failed: %d\n", (int)st);
+ }
+ }
+
+ if (has_bias && bias) {
+        add_bias_kernel<<<(M * N + 255) / 256, 256>>>(out, bias, dtype, M, N);
+ CUDA_KERNEL_CHECK();
+ }
+}
+} // namespace llaisys::ops::cuda
diff --git a/src/ops/linear/cuda/linear_cuda.cuh b/src/ops/linear/cuda/linear_cuda.cuh
new file mode 100644
index 000000000..248761923
--- /dev/null
+++ b/src/ops/linear/cuda/linear_cuda.cuh
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cuda {
+void linear(std::byte *out, const std::byte *in, const std::byte *weight, const std::byte *bias,
+ llaisysDataType_t dtype, size_t M, size_t N, size_t K, bool has_bias);
+}
diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp
index 97d1f8655..a71cb52ac 100644
--- a/src/ops/linear/op.cpp
+++ b/src/ops/linear/op.cpp
@@ -1,7 +1,45 @@
#include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/linear_cpu.hpp"
+#ifdef ENABLE_NVIDIA_API
+#include "cuda/linear_cuda.cuh"
+#endif
+
namespace llaisys::ops {
void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) {
- TO_BE_IMPLEMENTED();
+ ASSERT(out->ndim() == 2 && in->ndim() == 2 && weight->ndim() == 2,
+ "Linear: out, in, weight must be 2D.");
+ ASSERT(out->isContiguous() && in->isContiguous() && weight->isContiguous(),
+ "Linear: tensors must be contiguous.");
+
+ size_t M = in->shape()[0];
+ size_t K = in->shape()[1];
+ size_t N = weight->shape()[0];
+
+ bool has_bias = (bias != nullptr);
+ const std::byte *bias_data = has_bias ? bias->data() : nullptr;
+
+ if (out->deviceType() == LLAISYS_DEVICE_CPU) {
+ return cpu::linear(out->data(), in->data(), weight->data(), bias_data,
+ out->dtype(), M, N, K, has_bias);
+ }
+
+ llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+
+ switch (out->deviceType()) {
+ case LLAISYS_DEVICE_CPU:
+ return cpu::linear(out->data(), in->data(), weight->data(), bias_data,
+ out->dtype(), M, N, K, has_bias);
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return cuda::linear(out->data(), in->data(), weight->data(), bias_data,
+ out->dtype(), M, N, K, has_bias);
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
}
} // namespace llaisys::ops
diff --git a/src/ops/rearrange/cpu/rearrange_cpu.cpp b/src/ops/rearrange/cpu/rearrange_cpu.cpp
new file mode 100644
index 000000000..c63b47354
--- /dev/null
+++ b/src/ops/rearrange/cpu/rearrange_cpu.cpp
@@ -0,0 +1,28 @@
+#include "rearrange_cpu.hpp"
+
+#include <cstring>
+
+namespace llaisys::ops::cpu {
+void rearrange(std::byte *out, const std::byte *in,
+               const std::vector<size_t> &shape,
+               const std::vector<ptrdiff_t> &out_strides,
+               const std::vector<ptrdiff_t> &in_strides,
+               size_t esize, size_t numel) {
+    size_t ndim = shape.size();
+    std::vector<size_t> idx(ndim, 0);
+
+ for (size_t i = 0; i < numel; ++i) {
+ ptrdiff_t src_off = 0, dst_off = 0;
+ for (size_t d = 0; d < ndim; ++d) {
+ src_off += idx[d] * in_strides[d];
+ dst_off += idx[d] * out_strides[d];
+ }
+ std::memcpy(out + dst_off * esize, in + src_off * esize, esize);
+
+        for (int d = static_cast<int>(ndim) - 1; d >= 0; --d) {
+ if (++idx[d] < shape[d]) break;
+ idx[d] = 0;
+ }
+ }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/rearrange/cpu/rearrange_cpu.hpp b/src/ops/rearrange/cpu/rearrange_cpu.hpp
new file mode 100644
index 000000000..6ae100852
--- /dev/null
+++ b/src/ops/rearrange/cpu/rearrange_cpu.hpp
@@ -0,0 +1,14 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+namespace llaisys::ops::cpu {
+void rearrange(std::byte *out, const std::byte *in,
+               const std::vector<size_t> &shape,
+               const std::vector<ptrdiff_t> &out_strides,
+               const std::vector<ptrdiff_t> &in_strides,
+ size_t esize, size_t numel);
+}
diff --git a/src/ops/rearrange/cuda/rearrange_cuda.cu b/src/ops/rearrange/cuda/rearrange_cuda.cu
new file mode 100644
index 000000000..243ee2aa9
--- /dev/null
+++ b/src/ops/rearrange/cuda/rearrange_cuda.cu
@@ -0,0 +1,59 @@
+#include "rearrange_cuda.cuh"
+#include "../../cuda_utils.cuh"
+
+#include <cstddef>
+
+// Max supported dimensions for device-side arrays
+#define MAX_DIMS 8
+
+__global__ void rearrange_kernel(void *out, const void *in,
+ const size_t *d_shape,
+ const ptrdiff_t *d_out_strides,
+ const ptrdiff_t *d_in_strides,
+ size_t ndim, size_t esize, size_t numel) {
+ size_t flat_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (flat_idx >= numel) return;
+
+ // Convert flat index to multi-dimensional index
+ size_t remaining = flat_idx;
+ ptrdiff_t src_off = 0;
+ ptrdiff_t dst_off = 0;
+ for (size_t d = 0; d < ndim; d++) {
+ size_t prod = 1;
+ for (size_t dd = d + 1; dd < ndim; dd++) prod *= d_shape[dd];
+ size_t coord = remaining / prod;
+ remaining %= prod;
+ src_off += coord * d_in_strides[d];
+ dst_off += coord * d_out_strides[d];
+ }
+
+    const char *src = reinterpret_cast<const char *>(in) + src_off * esize;
+    char *dst = reinterpret_cast<char *>(out) + dst_off * esize;
+ for (size_t b = 0; b < esize; b++) {
+ dst[b] = src[b];
+ }
+}
+
+namespace llaisys::ops::cuda {
+void rearrange(std::byte *out, const std::byte *in,
+ const size_t *shape, const ptrdiff_t *out_strides, const ptrdiff_t *in_strides,
+ size_t ndim, size_t esize, size_t numel) {
+ // Copy shape and strides to device
+ size_t *d_shape;
+ ptrdiff_t *d_out_strides, *d_in_strides;
+ cudaMalloc(&d_shape, ndim * sizeof(size_t));
+ cudaMalloc(&d_out_strides, ndim * sizeof(ptrdiff_t));
+ cudaMalloc(&d_in_strides, ndim * sizeof(ptrdiff_t));
+ cudaMemcpy(d_shape, shape, ndim * sizeof(size_t), cudaMemcpyHostToDevice);
+ cudaMemcpy(d_out_strides, out_strides, ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice);
+ cudaMemcpy(d_in_strides, in_strides, ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice);
+
+    rearrange_kernel<<<(numel + 255) / 256, 256>>>(
+ out, in, d_shape, d_out_strides, d_in_strides, ndim, esize, numel);
+ CUDA_KERNEL_CHECK();
+
+ cudaFree(d_shape);
+ cudaFree(d_out_strides);
+ cudaFree(d_in_strides);
+}
+} // namespace llaisys::ops::cuda
diff --git a/src/ops/rearrange/cuda/rearrange_cuda.cuh b/src/ops/rearrange/cuda/rearrange_cuda.cuh
new file mode 100644
index 000000000..1a86e2808
--- /dev/null
+++ b/src/ops/rearrange/cuda/rearrange_cuda.cuh
@@ -0,0 +1,10 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+#include <cstdint>
+
+namespace llaisys::ops::cuda {
+void rearrange(std::byte *out, const std::byte *in,
+ const size_t *shape, const ptrdiff_t *out_strides, const ptrdiff_t *in_strides,
+ size_t ndim, size_t esize, size_t numel);
+}
diff --git a/src/ops/rearrange/op.cpp b/src/ops/rearrange/op.cpp
index 017a6ae59..9cea171b2 100644
--- a/src/ops/rearrange/op.cpp
+++ b/src/ops/rearrange/op.cpp
@@ -1,7 +1,39 @@
#include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/rearrange_cpu.hpp"
+#ifdef ENABLE_NVIDIA_API
+#include "cuda/rearrange_cuda.cuh"
+#endif
+
namespace llaisys::ops {
void rearrange(tensor_t out, tensor_t in) {
- TO_BE_IMPLEMENTED();
+ CHECK_SAME_SHAPE(out->shape(), in->shape());
+ CHECK_SAME_DTYPE(out->dtype(), in->dtype());
+
+ if (out->deviceType() == LLAISYS_DEVICE_CPU) {
+ return cpu::rearrange(out->data(), in->data(), out->shape(),
+ out->strides(), in->strides(),
+ out->elementSize(), out->numel());
+ }
+
+ llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+
+ switch (out->deviceType()) {
+ case LLAISYS_DEVICE_CPU:
+ return cpu::rearrange(out->data(), in->data(), out->shape(),
+ out->strides(), in->strides(),
+ out->elementSize(), out->numel());
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return cuda::rearrange(out->data(), in->data(),
+ out->shape().data(), out->strides().data(), in->strides().data(),
+ out->ndim(), out->elementSize(), out->numel());
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
}
} // namespace llaisys::ops
diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.cpp b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp
new file mode 100644
index 000000000..cad37c8e6
--- /dev/null
+++ b/src/ops/rms_norm/cpu/rms_norm_cpu.cpp
@@ -0,0 +1,105 @@
+#ifdef __AVX2__
+#include <immintrin.h>
+#endif
+
+#include "rms_norm_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cmath>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+template <typename T>
+void rms_norm_(T *out, const T *in, const T *weight, float eps, size_t rows, size_t cols) {
+    #pragma omp parallel for schedule(dynamic)
+    for (size_t r = 0; r < rows; r++) {
+        const T *row_in = in + r * cols;
+        T *row_out = out + r * cols;
+
+        float sum_sq = 0.0f;
+
+#ifdef __AVX2__
+        if constexpr (std::is_same_v<T, float>) {
+            __m256 vsum = _mm256_setzero_ps();
+            size_t c = 0;
+            for (; c + 8 <= cols; c += 8) {
+                __m256 vx = _mm256_loadu_ps(row_in + c);
+                vsum = _mm256_fmadd_ps(vx, vx, vsum);
+            }
+            float tmp[8];
+            _mm256_storeu_ps(tmp, vsum);
+            sum_sq = tmp[0] + tmp[1] + tmp[2] + tmp[3] +
+                     tmp[4] + tmp[5] + tmp[6] + tmp[7];
+            for (; c < cols; c++) {
+                float v = row_in[c];
+                sum_sq += v * v;
+            }
+        } else {
+            for (size_t c = 0; c < cols; c++) {
+                float v = llaisys::utils::cast<float>(row_in[c]);
+                sum_sq += v * v;
+            }
+        }
+#else
+        for (size_t c = 0; c < cols; c++) {
+            float v = llaisys::utils::cast<float>(row_in[c]);
+            sum_sq += v * v;
+        }
+#endif
+
+        float rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(cols) + eps);
+
+#ifdef __AVX2__
+        if constexpr (std::is_same_v<T, float>) {
+            __m256 vrms = _mm256_set1_ps(rms);
+            size_t c = 0;
+            for (; c + 8 <= cols; c += 8) {
+                __m256 vx = _mm256_loadu_ps(row_in + c);
+                __m256 vw = _mm256_loadu_ps(reinterpret_cast<const float *>(weight) + c);
+                __m256 vout = _mm256_mul_ps(_mm256_mul_ps(vw, vx), vrms);
+                _mm256_storeu_ps(row_out + c, vout);
+            }
+            for (; c < cols; c++) {
+                row_out[c] = weight[c] * row_in[c] * rms;
+            }
+        } else {
+            for (size_t c = 0; c < cols; c++) {
+                float v = llaisys::utils::cast<float>(row_in[c]);
+                float w = llaisys::utils::cast<float>(weight[c]);
+                row_out[c] = llaisys::utils::cast<T>(w * v * rms);
+            }
+        }
+#else
+        for (size_t c = 0; c < cols; c++) {
+            float v = llaisys::utils::cast<float>(row_in[c]);
+            float w = llaisys::utils::cast<float>(weight[c]);
+            row_out[c] = llaisys::utils::cast<T>(w * v * rms);
+        }
+#endif
+    }
+}
+
+namespace llaisys::ops::cpu {
+void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight,
+ float eps, llaisysDataType_t dtype, size_t rows, size_t cols) {
+ switch (dtype) {
+ case LLAISYS_DTYPE_F32:
+        return rms_norm_(reinterpret_cast<float *>(out),
+                         reinterpret_cast<const float *>(in),
+                         reinterpret_cast<const float *>(weight), eps, rows, cols);
+    case LLAISYS_DTYPE_BF16:
+        return rms_norm_(reinterpret_cast<llaisys::bf16_t *>(out),
+                         reinterpret_cast<const llaisys::bf16_t *>(in),
+                         reinterpret_cast<const llaisys::bf16_t *>(weight), eps, rows, cols);
+    case LLAISYS_DTYPE_F16:
+        return rms_norm_(reinterpret_cast<llaisys::fp16_t *>(out),
+                         reinterpret_cast<const llaisys::fp16_t *>(in),
+                         reinterpret_cast<const llaisys::fp16_t *>(weight), eps, rows, cols);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+ }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/rms_norm/cpu/rms_norm_cpu.hpp b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp
new file mode 100644
index 000000000..bd5862701
--- /dev/null
+++ b/src/ops/rms_norm/cpu/rms_norm_cpu.hpp
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight,
+ float eps, llaisysDataType_t dtype, size_t rows, size_t cols);
+}
diff --git a/src/ops/rms_norm/cuda/rms_norm_cuda.cu b/src/ops/rms_norm/cuda/rms_norm_cuda.cu
new file mode 100644
index 000000000..67a7bf6ee
--- /dev/null
+++ b/src/ops/rms_norm/cuda/rms_norm_cuda.cu
@@ -0,0 +1,51 @@
+#include "rms_norm_cuda.cuh"
+#include "../../cuda_utils.cuh"
+
+#include <cstddef>
+
+// Each block handles one row. Block-level reduction for sum of squares.
+__global__ void rms_norm_kernel(void *out, const void *in, const void *weight,
+ float eps, llaisysDataType_t dtype,
+ size_t rows, size_t cols) {
+ size_t row = blockIdx.x;
+ if (row >= rows) return;
+
+ extern __shared__ float sdata[];
+
+ float local_sum = 0.0f;
+ for (size_t c = threadIdx.x; c < cols; c += blockDim.x) {
+ float v = load_as_f32(in, row * cols + c, dtype);
+ local_sum += v * v;
+ }
+
+ sdata[threadIdx.x] = local_sum;
+ __syncthreads();
+
+ // Block reduction
+ for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+ if (threadIdx.x < s) {
+ sdata[threadIdx.x] += sdata[threadIdx.x + s];
+ }
+ __syncthreads();
+ }
+
+    float rms = rsqrtf(sdata[0] / static_cast<float>(cols) + eps);
+
+ for (size_t c = threadIdx.x; c < cols; c += blockDim.x) {
+ float v = load_as_f32(in, row * cols + c, dtype);
+ float w = load_as_f32(weight, c, dtype);
+ store_from_f32(out, row * cols + c, w * v * rms, dtype);
+ }
+}
+
+namespace llaisys::ops::cuda {
+void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight,
+ float eps, llaisysDataType_t dtype, size_t rows, size_t cols) {
+ int block_size = 256;
+ if (cols > 256) block_size = 512;
+ if (cols > 512) block_size = 1024;
+ size_t shared_mem = block_size * sizeof(float);
+    rms_norm_kernel<<<rows, block_size, shared_mem>>>(out, in, weight, eps, dtype, rows, cols);
+ CUDA_KERNEL_CHECK();
+}
+} // namespace llaisys::ops::cuda
diff --git a/src/ops/rms_norm/cuda/rms_norm_cuda.cuh b/src/ops/rms_norm/cuda/rms_norm_cuda.cuh
new file mode 100644
index 000000000..96f720800
--- /dev/null
+++ b/src/ops/rms_norm/cuda/rms_norm_cuda.cuh
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cuda {
+void rms_norm(std::byte *out, const std::byte *in, const std::byte *weight,
+ float eps, llaisysDataType_t dtype, size_t rows, size_t cols);
+}
diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp
index 529553d9d..778628fdf 100644
--- a/src/ops/rms_norm/op.cpp
+++ b/src/ops/rms_norm/op.cpp
@@ -1,7 +1,36 @@
#include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/rms_norm_cpu.hpp"
+#ifdef ENABLE_NVIDIA_API
+#include "cuda/rms_norm_cuda.cuh"
+#endif
+
namespace llaisys::ops {
void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) {
- TO_BE_IMPLEMENTED();
+ ASSERT(out->ndim() == 2 && in->ndim() == 2, "RmsNorm: out and in must be 2D.");
+ ASSERT(out->isContiguous() && in->isContiguous(), "RmsNorm: tensors must be contiguous.");
+
+ size_t rows = in->shape()[0];
+ size_t cols = in->shape()[1];
+
+ if (out->deviceType() == LLAISYS_DEVICE_CPU) {
+ return cpu::rms_norm(out->data(), in->data(), weight->data(), eps, out->dtype(), rows, cols);
+ }
+
+ llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+
+ switch (out->deviceType()) {
+ case LLAISYS_DEVICE_CPU:
+ return cpu::rms_norm(out->data(), in->data(), weight->data(), eps, out->dtype(), rows, cols);
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return cuda::rms_norm(out->data(), in->data(), weight->data(), eps, out->dtype(), rows, cols);
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
}
} // namespace llaisys::ops
diff --git a/src/ops/rope/cpu/rope_cpu.cpp b/src/ops/rope/cpu/rope_cpu.cpp
new file mode 100644
index 000000000..269e3c385
--- /dev/null
+++ b/src/ops/rope/cpu/rope_cpu.cpp
@@ -0,0 +1,66 @@
+#include "rope_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cmath>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+template <typename T>
+void rope_(T *out, const T *in, const int64_t *pos_ids,
+           float theta, size_t seqlen, size_t nhead, size_t d) {
+    size_t half_d = d / 2;
+
+    // Precompute theta powers to avoid redundant pow() calls per element
+    std::vector<float> theta_pow(half_d);
+    for (size_t j = 0; j < half_d; j++) {
+        theta_pow[j] = std::pow(theta, 2.0f * static_cast<float>(j) / static_cast<float>(d));
+    }
+
+    #pragma omp parallel for collapse(2) schedule(static)
+    for (size_t s = 0; s < seqlen; s++) {
+        for (size_t h = 0; h < nhead; h++) {
+            float pos = static_cast<float>(pos_ids[s]);
+            const T *x = in + (s * nhead + h) * d;
+            T *y = out + (s * nhead + h) * d;
+            const T *a = x;
+            const T *b = x + half_d;
+            T *a_out = y;
+            T *b_out = y + half_d;
+
+            for (size_t j = 0; j < half_d; j++) {
+                float phi = pos / theta_pow[j];
+                float cos_phi = std::cos(phi);
+                float sin_phi = std::sin(phi);
+                float a_val = llaisys::utils::cast<float>(a[j]);
+                float b_val = llaisys::utils::cast<float>(b[j]);
+                a_out[j] = llaisys::utils::cast<T>(a_val * cos_phi - b_val * sin_phi);
+                b_out[j] = llaisys::utils::cast<T>(b_val * cos_phi + a_val * sin_phi);
+            }
+        }
+    }
+}
+
+namespace llaisys::ops::cpu {
+void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids,
+ float theta, llaisysDataType_t dtype,
+ size_t seqlen, size_t nhead, size_t d) {
+    auto *pids = reinterpret_cast<const int64_t *>(pos_ids);
+    switch (dtype) {
+    case LLAISYS_DTYPE_F32:
+        return rope_(reinterpret_cast<float *>(out), reinterpret_cast<const float *>(in),
+                     pids, theta, seqlen, nhead, d);
+    case LLAISYS_DTYPE_BF16:
+        return rope_(reinterpret_cast<llaisys::bf16_t *>(out), reinterpret_cast<const llaisys::bf16_t *>(in),
+                     pids, theta, seqlen, nhead, d);
+    case LLAISYS_DTYPE_F16:
+        return rope_(reinterpret_cast<llaisys::fp16_t *>(out), reinterpret_cast<const llaisys::fp16_t *>(in),
+                     pids, theta, seqlen, nhead, d);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(dtype);
+ }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/rope/cpu/rope_cpu.hpp b/src/ops/rope/cpu/rope_cpu.hpp
new file mode 100644
index 000000000..7a525eb41
--- /dev/null
+++ b/src/ops/rope/cpu/rope_cpu.hpp
@@ -0,0 +1,10 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids,
+ float theta, llaisysDataType_t dtype,
+ size_t seqlen, size_t nhead, size_t d);
+}
diff --git a/src/ops/rope/cuda/rope_cuda.cu b/src/ops/rope/cuda/rope_cuda.cu
new file mode 100644
index 000000000..4b4cf4a2e
--- /dev/null
+++ b/src/ops/rope/cuda/rope_cuda.cu
@@ -0,0 +1,41 @@
+#include "rope_cuda.cuh"
+#include "../../cuda_utils.cuh"
+
+// Each thread handles one (seq, head, pair_idx) triple
+__global__ void rope_kernel(void *out, const void *in, const int64_t *pos_ids,
+ float theta, llaisysDataType_t dtype,
+ size_t seqlen, size_t nhead, size_t d) {
+ size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+ size_t half_d = d / 2;
+ size_t total = seqlen * nhead * half_d;
+ if (idx >= total) return;
+
+ size_t j = idx % half_d;
+ size_t h = (idx / half_d) % nhead;
+ size_t s = idx / (half_d * nhead);
+
+    float pos = static_cast<float>(pos_ids[s]);
+    float theta_pow = powf(theta, 2.0f * static_cast<float>(j) / static_cast<float>(d));
+ float phi = pos / theta_pow;
+ float cos_phi = cosf(phi);
+ float sin_phi = sinf(phi);
+
+ size_t base = (s * nhead + h) * d;
+ float a_val = load_as_f32(in, base + j, dtype);
+ float b_val = load_as_f32(in, base + half_d + j, dtype);
+
+ store_from_f32(out, base + j, a_val * cos_phi - b_val * sin_phi, dtype);
+ store_from_f32(out, base + half_d + j, b_val * cos_phi + a_val * sin_phi, dtype);
+}
+
+namespace llaisys::ops::cuda {
+void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids,
+ float theta, llaisysDataType_t dtype,
+ size_t seqlen, size_t nhead, size_t d) {
+ size_t total = seqlen * nhead * (d / 2);
+    rope_kernel<<<(total + 255) / 256, 256>>>(
+        out, in, reinterpret_cast<const int64_t *>(pos_ids),
+ theta, dtype, seqlen, nhead, d);
+ CUDA_KERNEL_CHECK();
+}
+} // namespace llaisys::ops::cuda
diff --git a/src/ops/rope/cuda/rope_cuda.cuh b/src/ops/rope/cuda/rope_cuda.cuh
new file mode 100644
index 000000000..fb8c9014e
--- /dev/null
+++ b/src/ops/rope/cuda/rope_cuda.cuh
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cuda {
+void rope(std::byte *out, const std::byte *in, const std::byte *pos_ids,
+ float theta, llaisysDataType_t dtype,
+ size_t seqlen, size_t nhead, size_t d);
+}
diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp
index d60dbe64e..88e133560 100644
--- a/src/ops/rope/op.cpp
+++ b/src/ops/rope/op.cpp
@@ -1,7 +1,38 @@
#include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/rope_cpu.hpp"
+#ifdef ENABLE_NVIDIA_API
+#include "cuda/rope_cuda.cuh"
+#endif
+
namespace llaisys::ops {
void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) {
- TO_BE_IMPLEMENTED();
+ ASSERT(out->ndim() == 3 && in->ndim() == 3, "RoPE: out and in must be 3D [seqlen, nhead, d].");
+ ASSERT(out->isContiguous() && in->isContiguous(), "RoPE: tensors must be contiguous.");
+ ASSERT(pos_ids->dtype() == LLAISYS_DTYPE_I64, "RoPE: pos_ids must be int64.");
+
+ size_t seqlen = in->shape()[0];
+ size_t nhead = in->shape()[1];
+ size_t d = in->shape()[2];
+
+ if (out->deviceType() == LLAISYS_DEVICE_CPU) {
+ return cpu::rope(out->data(), in->data(), pos_ids->data(), theta, out->dtype(), seqlen, nhead, d);
+ }
+
+ llaisys::core::context().setDevice(out->deviceType(), out->deviceId());
+
+ switch (out->deviceType()) {
+ case LLAISYS_DEVICE_CPU:
+ return cpu::rope(out->data(), in->data(), pos_ids->data(), theta, out->dtype(), seqlen, nhead, d);
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return cuda::rope(out->data(), in->data(), pos_ids->data(), theta, out->dtype(), seqlen, nhead, d);
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
}
} // namespace llaisys::ops
diff --git a/src/ops/sample/cpu/sample_cpu.cpp b/src/ops/sample/cpu/sample_cpu.cpp
new file mode 100644
index 000000000..d09ff8daf
--- /dev/null
+++ b/src/ops/sample/cpu/sample_cpu.cpp
@@ -0,0 +1,96 @@
+#ifdef __AVX2__
+#include <immintrin.h>
+#endif
+
+#include "sample_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <numeric>
+#include <random>
+#include <vector>
+
+static thread_local std::mt19937 rng{std::random_device{}()};
+
+template <typename T>
+void sample_(int64_t *out_idx, const T *logits, size_t numel,
+             float temperature, int top_k, float top_p) {
+    std::vector<float> probs(numel);
+    for (size_t i = 0; i < numel; i++) {
+        probs[i] = llaisys::utils::cast<float>(logits[i]);
+    }
+
+    if (temperature <= 0.0f) temperature = 1.0f;
+    if (temperature != 1.0f) {
+        for (size_t i = 0; i < numel; i++) {
+            probs[i] /= temperature;
+        }
+    }
+
+    // Build index array sorted by descending logit value
+    std::vector<int> indices(numel);
+    std::iota(indices.begin(), indices.end(), 0);
+    std::sort(indices.begin(), indices.end(),
+              [&](int a, int b) { return probs[a] > probs[b]; });
+
+    // Top-K: keep at most top_k candidates
+    size_t keep = numel;
+    if (top_k > 0 && static_cast<size_t>(top_k) < numel) {
+        keep = static_cast<size_t>(top_k);
+    }
+
+    // Softmax over the kept candidates
+    float max_val = probs[indices[0]];
+    std::vector<float> softmax_vals(keep);
+    float sum_exp = 0.0f;
+    for (size_t i = 0; i < keep; i++) {
+        softmax_vals[i] = std::exp(probs[indices[i]] - max_val);
+        sum_exp += softmax_vals[i];
+    }
+    for (size_t i = 0; i < keep; i++) {
+        softmax_vals[i] /= sum_exp;
+    }
+
+    // Top-P (nucleus): find cutoff where cumulative prob >= top_p
+    if (top_p > 0.0f && top_p < 1.0f) {
+        float cumsum = 0.0f;
+        size_t cutoff = keep;
+        for (size_t i = 0; i < keep; i++) {
+            cumsum += softmax_vals[i];
+            if (cumsum >= top_p) {
+                cutoff = i + 1;
+                break;
+            }
+        }
+        keep = cutoff;
+        // Re-normalize
+        float new_sum = 0.0f;
+        for (size_t i = 0; i < keep; i++) new_sum += softmax_vals[i];
+        for (size_t i = 0; i < keep; i++) softmax_vals[i] /= new_sum;
+    }
+
+    // Sample from the distribution
+    std::discrete_distribution<int> dist(softmax_vals.begin(), softmax_vals.begin() + keep);
+    int sampled = dist(rng);
+    *out_idx = static_cast<int64_t>(indices[sampled]);
+}
+
+namespace llaisys::ops::cpu {
+void sample(std::byte *out_idx, const std::byte *logits, llaisysDataType_t type, size_t numel,
+ float temperature, int top_k, float top_p) {
+    auto *idx_ptr = reinterpret_cast<int64_t *>(out_idx);
+    switch (type) {
+    case LLAISYS_DTYPE_F32:
+        return sample_(idx_ptr, reinterpret_cast<const float *>(logits), numel, temperature, top_k, top_p);
+    case LLAISYS_DTYPE_BF16:
+        return sample_(idx_ptr, reinterpret_cast<const llaisys::bf16_t *>(logits), numel, temperature, top_k, top_p);
+    case LLAISYS_DTYPE_F16:
+        return sample_(idx_ptr, reinterpret_cast<const llaisys::fp16_t *>(logits), numel, temperature, top_k, top_p);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+}
+} // namespace llaisys::ops::cpu
diff --git a/src/ops/sample/cpu/sample_cpu.hpp b/src/ops/sample/cpu/sample_cpu.hpp
new file mode 100644
index 000000000..611a18300
--- /dev/null
+++ b/src/ops/sample/cpu/sample_cpu.hpp
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void sample(std::byte *out_idx, const std::byte *logits, llaisysDataType_t type, size_t numel,
+ float temperature, int top_k, float top_p);
+}
diff --git a/src/ops/sample/cuda/sample_cuda.cu b/src/ops/sample/cuda/sample_cuda.cu
new file mode 100644
index 000000000..4a37bdc8b
--- /dev/null
+++ b/src/ops/sample/cuda/sample_cuda.cu
@@ -0,0 +1,103 @@
+#include "sample_cuda.cuh"
+#include "../../cuda_utils.cuh"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <cstring>
+#include <numeric>
+#include <random>
+#include <vector>
+
+static thread_local std::mt19937 rng{std::random_device{}()};
+
+namespace llaisys::ops::cuda {
+void sample(std::byte *out_idx, const std::byte *logits, llaisysDataType_t type, size_t numel,
+ float temperature, int top_k, float top_p) {
+ // Copy logits from GPU to CPU, do sampling on CPU, copy result back
+ size_t esize = cuda_dsize(type);
+    std::vector<uint8_t> host_logits(numel * esize);
+    cudaMemcpy(host_logits.data(), logits, numel * esize, cudaMemcpyDeviceToHost);
+
+    // Convert to float
+    std::vector<float> probs(numel);
+    for (size_t i = 0; i < numel; i++) {
+        if (type == LLAISYS_DTYPE_F32) {
+            probs[i] = reinterpret_cast<const float *>(host_logits.data())[i];
+        } else if (type == LLAISYS_DTYPE_BF16) {
+            uint16_t v = reinterpret_cast<const uint16_t *>(host_logits.data())[i];
+            uint32_t bits = static_cast<uint32_t>(v) << 16;
+            float f;
+            std::memcpy(&f, &bits, sizeof(float));
+            probs[i] = f;
+        } else if (type == LLAISYS_DTYPE_F16) {
+            uint16_t v = reinterpret_cast<const uint16_t *>(host_logits.data())[i];
+ // Simple F16 -> F32 conversion
+ uint32_t sign = (v >> 15) & 0x1;
+ uint32_t exp = (v >> 10) & 0x1F;
+ uint32_t mant = v & 0x3FF;
+ uint32_t f32_bits;
+ if (exp == 0) {
+ f32_bits = sign << 31;
+ } else if (exp == 0x1F) {
+ f32_bits = (sign << 31) | 0x7F800000 | (mant << 13);
+ } else {
+ f32_bits = (sign << 31) | ((exp + 112) << 23) | (mant << 13);
+ }
+ float f;
+ std::memcpy(&f, &f32_bits, sizeof(float));
+ probs[i] = f;
+ }
+ }
+
+ if (temperature <= 0.0f) temperature = 1.0f;
+ if (temperature != 1.0f) {
+ for (size_t i = 0; i < numel; i++) {
+ probs[i] /= temperature;
+ }
+ }
+
+    std::vector<int> indices(numel);
+ std::iota(indices.begin(), indices.end(), 0);
+ std::sort(indices.begin(), indices.end(),
+ [&](int a, int b) { return probs[a] > probs[b]; });
+
+ size_t keep = numel;
+    if (top_k > 0 && static_cast<size_t>(top_k) < numel) {
+        keep = static_cast<size_t>(top_k);
+ }
+
+ float max_val = probs[indices[0]];
+    std::vector<float> softmax_vals(keep);
+ float sum_exp = 0.0f;
+ for (size_t i = 0; i < keep; i++) {
+ softmax_vals[i] = std::exp(probs[indices[i]] - max_val);
+ sum_exp += softmax_vals[i];
+ }
+ for (size_t i = 0; i < keep; i++) {
+ softmax_vals[i] /= sum_exp;
+ }
+
+ if (top_p > 0.0f && top_p < 1.0f) {
+ float cumsum = 0.0f;
+ size_t cutoff = keep;
+ for (size_t i = 0; i < keep; i++) {
+ cumsum += softmax_vals[i];
+ if (cumsum >= top_p) {
+ cutoff = i + 1;
+ break;
+ }
+ }
+ keep = cutoff;
+ float new_sum = 0.0f;
+ for (size_t i = 0; i < keep; i++) new_sum += softmax_vals[i];
+ for (size_t i = 0; i < keep; i++) softmax_vals[i] /= new_sum;
+ }
+
+    std::discrete_distribution<int> dist(softmax_vals.begin(), softmax_vals.begin() + keep);
+    int sampled = dist(rng);
+    int64_t result = static_cast<int64_t>(indices[sampled]);
+
+ cudaMemcpy(out_idx, &result, sizeof(int64_t), cudaMemcpyHostToDevice);
+}
+} // namespace llaisys::ops::cuda
diff --git a/src/ops/sample/cuda/sample_cuda.cuh b/src/ops/sample/cuda/sample_cuda.cuh
new file mode 100644
index 000000000..70ee69d6e
--- /dev/null
+++ b/src/ops/sample/cuda/sample_cuda.cuh
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+#include <cstddef>
+
+namespace llaisys::ops::cuda {
+void sample(std::byte *out_idx, const std::byte *logits, llaisysDataType_t type, size_t numel,
+ float temperature, int top_k, float top_p);
+}
diff --git a/src/ops/sample/op.cpp b/src/ops/sample/op.cpp
new file mode 100644
index 000000000..c7a242b41
--- /dev/null
+++ b/src/ops/sample/op.cpp
@@ -0,0 +1,35 @@
+#include "op.hpp"
+
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/sample_cpu.hpp"
+#ifdef ENABLE_NVIDIA_API
+#include "cuda/sample_cuda.cuh"
+#endif
+
+namespace llaisys::ops {
+void sample(tensor_t out_idx, tensor_t logits, float temperature, int top_k, float top_p) {
+ ASSERT(logits->isContiguous(), "Sample: logits must be contiguous.");
+
+ if (logits->deviceType() == LLAISYS_DEVICE_CPU) {
+ return cpu::sample(out_idx->data(), logits->data(), logits->dtype(), logits->numel(),
+ temperature, top_k, top_p);
+ }
+
+ llaisys::core::context().setDevice(logits->deviceType(), logits->deviceId());
+
+ switch (logits->deviceType()) {
+ case LLAISYS_DEVICE_CPU:
+ return cpu::sample(out_idx->data(), logits->data(), logits->dtype(), logits->numel(),
+ temperature, top_k, top_p);
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return cuda::sample(out_idx->data(), logits->data(), logits->dtype(), logits->numel(),
+ temperature, top_k, top_p);
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
+}
+} // namespace llaisys::ops
diff --git a/src/ops/sample/op.hpp b/src/ops/sample/op.hpp
new file mode 100644
index 000000000..e815ff784
--- /dev/null
+++ b/src/ops/sample/op.hpp
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "../../tensor/tensor.hpp"
+
+namespace llaisys::ops {
+void sample(tensor_t out_idx, tensor_t logits, float temperature, int top_k, float top_p);
+}
diff --git a/src/ops/self_attention/cpu/self_attention_cpu.cpp b/src/ops/self_attention/cpu/self_attention_cpu.cpp
new file mode 100644
index 000000000..692e41058
--- /dev/null
+++ b/src/ops/self_attention/cpu/self_attention_cpu.cpp
@@ -0,0 +1,170 @@
+#ifdef __AVX2__
+#include <immintrin.h>
+#endif
+
+#include "self_attention_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <cmath>
+#include <type_traits>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#ifdef __AVX2__
+static inline float avx2_dot(const float *a, const float *b, size_t n) {
+ __m256 vsum = _mm256_setzero_ps();
+ size_t i = 0;
+ for (; i + 8 <= n; i += 8) {
+ __m256 va = _mm256_loadu_ps(a + i);
+ __m256 vb = _mm256_loadu_ps(b + i);
+ vsum = _mm256_fmadd_ps(va, vb, vsum);
+ }
+ float tmp[8];
+ _mm256_storeu_ps(tmp, vsum);
+ float sum = tmp[0] + tmp[1] + tmp[2] + tmp[3] +
+ tmp[4] + tmp[5] + tmp[6] + tmp[7];
+ for (; i < n; i++)
+ sum += a[i] * b[i];
+ return sum;
+}
+#endif
+
+template <typename T>
+void self_attention_(T *attn_val, const T *q, const T *k, const T *v,
+                     float scale, size_t qlen, size_t kvlen,
+                     size_t nh, size_t nkvh, size_t d) {
+    size_t group_size = nh / nkvh;
+
+    bool need_cast = !std::is_same<T, float>::value;
+
+    std::vector<float> fq, fk, fv;
+    if (need_cast) {
+        fq.resize(qlen * nh * d);
+        fk.resize(kvlen * nkvh * d);
+        fv.resize(kvlen * nkvh * d);
+
+        #pragma omp parallel for schedule(static)
+        for (size_t i = 0; i < qlen * nh * d; i++)
+            fq[i] = llaisys::utils::cast<float>(q[i]);
+        #pragma omp parallel for schedule(static)
+        for (size_t i = 0; i < kvlen * nkvh * d; i++)
+            fk[i] = llaisys::utils::cast<float>(k[i]);
+        #pragma omp parallel for schedule(static)
+        for (size_t i = 0; i < kvlen * nkvh * d; i++)
+            fv[i] = llaisys::utils::cast<float>(v[i]);
+    }
+
+    const float *qf = need_cast ? fq.data() : reinterpret_cast<const float *>(q);
+    const float *kf = need_cast ? fk.data() : reinterpret_cast<const float *>(k);
+    const float *vf = need_cast ? fv.data() : reinterpret_cast<const float *>(v);
+
+ #pragma omp parallel for schedule(dynamic)
+ for (size_t h = 0; h < nh; h++) {
+ size_t kvh = h / group_size;
+
+ std::vector