diff --git a/.gitignore b/.gitignore
index e38cf574..e9960a4d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,10 @@
+# models
+/models/
# Xmake cache
.xmake/
build/
-
+# md docs
+*.md
# Binaries
bin/
lib/
@@ -87,4 +90,4 @@ htmlcov/
# Windows
Thumbs.db
ehthumbs.db
-desktop.ini
\ No newline at end of file
+desktop.ini
diff --git a/=3.1.0 b/=3.1.0
new file mode 100644
index 00000000..8151c7c6
--- /dev/null
+++ b/=3.1.0
@@ -0,0 +1,2 @@
+Defaulting to user installation because normal site-packages is not writeable
+Requirement already satisfied: jinja2 in /usr/lib/python3/dist-packages (3.0.3)
diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h
index 7054626d..8840660f 100644
--- a/include/llaisys/models/qwen2.h
+++ b/include/llaisys/models/qwen2.h
@@ -37,6 +37,7 @@ __C {
__export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model);
- __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken);
+ __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken,
+ float temperature, size_t topK, float topP, int64_t seed);
}
#endif // LLAISYS_MODELS_QWEN2_H
diff --git a/include/llaisys/ops.h b/include/llaisys/ops.h
index ddb3be24..15b70067 100644
--- a/include/llaisys/ops.h
+++ b/include/llaisys/ops.h
@@ -6,6 +6,8 @@
__C {
__export void llaisysAdd(llaisysTensor_t c, llaisysTensor_t a, llaisysTensor_t b);
__export void llaisysArgmax(llaisysTensor_t max_idx, llaisysTensor_t max_val, llaisysTensor_t vals);
+ __export void llaisysRandSample(llaisysTensor_t sample_idx, llaisysTensor_t sample_val, llaisysTensor_t vals,
+ float temperature, size_t topK, float topP, int64_t seed);
__export void llaisysEmbedding(llaisysTensor_t out, llaisysTensor_t index, llaisysTensor_t weight);
__export void llaisysLinear(llaisysTensor_t out, llaisysTensor_t in, llaisysTensor_t weight, llaisysTensor_t bias);
__export void llaisysRearrange(llaisysTensor_t out, llaisysTensor_t in);
diff --git a/python/llaisys/libllaisys/__init__.py b/python/llaisys/libllaisys/__init__.py
index f536fb52..fa872903 100644
--- a/python/llaisys/libllaisys/__init__.py
+++ b/python/llaisys/libllaisys/__init__.py
@@ -12,6 +12,10 @@
from .tensor import llaisysTensor_t
from .tensor import load_tensor
from .ops import load_ops
+# Llaisys infer
+from .models import load_models
+from .models import LlaisysQwen2Meta
+from .models import LlaisysQwen2Weights
def load_shared_library():
@@ -38,6 +42,8 @@ def load_shared_library():
load_runtime(LIB_LLAISYS)
load_tensor(LIB_LLAISYS)
load_ops(LIB_LLAISYS)
+# Llaisys load_models
+load_models(LIB_LLAISYS)
__all__ = [
@@ -52,4 +58,7 @@ def load_shared_library():
"llaisysMemcpyKind_t",
"MemcpyKind",
"llaisysStream_t",
+ # Llaisys c side
+ "LlaisysQwen2Meta",
+ "LlaisysQwen2Weights",
]
diff --git a/python/llaisys/libllaisys/models.py b/python/llaisys/libllaisys/models.py
new file mode 100644
index 00000000..a3807fae
--- /dev/null
+++ b/python/llaisys/libllaisys/models.py
@@ -0,0 +1,69 @@
+import ctypes
+from ctypes import c_size_t, c_int64, c_int, c_float
+from .llaisys_types import llaisysDeviceType_t, llaisysDataType_t
+from .tensor import llaisysTensor_t
+
+# c side wrap
+
+class LlaisysQwen2Meta(ctypes.Structure):
+ _fields_ = [
+ ("dtype", llaisysDataType_t),
+ ("nlayer", c_size_t),
+ ("hs", c_size_t),
+ ("nh", c_size_t),
+ ("nkvh", c_size_t),
+ ("dh", c_size_t),
+ ("di", c_size_t),
+ ("maxseq", c_size_t),
+ ("voc", c_size_t),
+ ("epsilon", c_float),
+ ("theta", c_float),
+ ("end_token", c_int64),
+ ]
+
+
+class LlaisysQwen2Weights(ctypes.Structure):
+ _fields_ = [
+ ("in_embed", llaisysTensor_t),
+ ("out_embed", llaisysTensor_t),
+ ("out_norm_w", llaisysTensor_t),
+ ("attn_norm_w", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_q_w", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_q_b", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_k_w", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_k_b", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_v_w", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_v_b", ctypes.POINTER(llaisysTensor_t)),
+ ("attn_o_w", ctypes.POINTER(llaisysTensor_t)),
+ ("mlp_norm_w", ctypes.POINTER(llaisysTensor_t)),
+ ("mlp_gate_w", ctypes.POINTER(llaisysTensor_t)),
+ ("mlp_up_w", ctypes.POINTER(llaisysTensor_t)),
+ ("mlp_down_w", ctypes.POINTER(llaisysTensor_t)),
+ ]
+
+
+def load_models(lib):
+ lib.llaisysQwen2ModelCreate.argtypes = [
+ ctypes.POINTER(LlaisysQwen2Meta),
+ llaisysDeviceType_t,
+ ctypes.POINTER(c_int),
+ c_int,
+ ]
+ lib.llaisysQwen2ModelCreate.restype = ctypes.c_void_p
+
+ lib.llaisysQwen2ModelDestroy.argtypes = [ctypes.c_void_p]
+ lib.llaisysQwen2ModelDestroy.restype = None
+
+ lib.llaisysQwen2ModelWeights.argtypes = [ctypes.c_void_p]
+ lib.llaisysQwen2ModelWeights.restype = ctypes.POINTER(LlaisysQwen2Weights)
+
+ lib.llaisysQwen2ModelInfer.argtypes = [
+ ctypes.c_void_p,
+ ctypes.POINTER(c_int64),
+ c_size_t,
+ c_float,
+ c_size_t,
+ c_float,
+ c_int64,
+ ]
+ lib.llaisysQwen2ModelInfer.restype = c_int64
diff --git a/python/llaisys/libllaisys/ops.py b/python/llaisys/libllaisys/ops.py
index 5be095ef..e6ecfaa6 100644
--- a/python/llaisys/libllaisys/ops.py
+++ b/python/llaisys/libllaisys/ops.py
@@ -1,5 +1,5 @@
from .tensor import llaisysTensor_t
-from ctypes import c_float
+from ctypes import c_float, c_int64, c_size_t
def load_ops(lib):
lib.llaisysAdd.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t]
@@ -8,6 +8,17 @@ def load_ops(lib):
lib.llaisysArgmax.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t]
lib.llaisysArgmax.restype = None
+ lib.llaisysRandSample.argtypes = [
+ llaisysTensor_t,
+ llaisysTensor_t,
+ llaisysTensor_t,
+ c_float,
+ c_size_t,
+ c_float,
+ c_int64,
+ ]
+ lib.llaisysRandSample.restype = None
+
lib.llaisysEmbedding.argtypes = [llaisysTensor_t, llaisysTensor_t, llaisysTensor_t]
lib.llaisysEmbedding.restype = None
diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py
index 0d07b0b2..c89cbb04 100644
--- a/python/llaisys/models/qwen2.py
+++ b/python/llaisys/models/qwen2.py
@@ -1,23 +1,177 @@
from typing import Sequence
from ..libllaisys import LIB_LLAISYS
from ..libllaisys import DeviceType
+from ..libllaisys import DataType
+from ..libllaisys import llaisysDeviceType_t
+from ..libllaisys.models import LlaisysQwen2Meta
from pathlib import Path
+import os
import safetensors
+import json
+import ctypes
+import torch
class Qwen2:
-
def __init__(self, model_path, device: DeviceType = DeviceType.CPU):
- # TODO: Implement model constructor
-
+ self._device = device
model_path = Path(model_path)
+ config_path = model_path / "config.json"
+ # read model config from config.json
+ with open(config_path, "r", encoding="utf-8") as f:
+ config = json.load(f)
+
+ # parse dim
+ hs = int(config["hidden_size"])
+ nlayer = int(config["num_hidden_layers"])
+ nh = int(config["num_attention_heads"])
+ nkvh = int(config.get("num_key_value_heads", nh))
+ di = int(config["intermediate_size"])
+ dh = int(hs // nh)
+
+ # parse key params
+ maxseq = int(config["max_position_embeddings"])
+ voc = int(config["vocab_size"])
+ epsilon = float(config["rms_norm_eps"])
+ theta = float(config["rope_theta"])
+ end_token = int(config["eos_token_id"])
+
+
+ dtype = self._select_dtype(device)
+ self._dtype = dtype
+ # construct C struct LlaisysQwen2Meta
+ meta = LlaisysQwen2Meta(
+ dtype=dtype,
+ nlayer=nlayer,
+ hs=hs,
+ nh=nh,
+ nkvh=nkvh,
+ dh=dh,
+ di=di,
+ maxseq=maxseq,
+ voc=voc,
+ epsilon=epsilon,
+ theta=theta,
+ end_token=end_token,
+ )
+ device_ids = (ctypes.c_int * 1)(0)
+ # create model instance
+ self._model = LIB_LLAISYS.llaisysQwen2ModelCreate(
+ ctypes.byref(meta), llaisysDeviceType_t(device), device_ids, 1
+ )
+
+ # get model weights
+ self._weights = LIB_LLAISYS.llaisysQwen2ModelWeights(self._model).contents
+ self._end_token = end_token
+
+ # traverse all safetensors files, in fact only one file in qwen2
for file in sorted(model_path.glob("*.safetensors")):
- data_ = safetensors.safe_open(file, framework="numpy", device="cpu")
+ # load on cpu, I use pt framework to load bfloat16 weights here
+ data_ = safetensors.safe_open(file, framework="pt", device="cpu")
for name_ in data_.keys():
- ## TODO: load the model weights
- pass
+ weight = self._match_weight(name_)
+ if weight is None:
+ continue
+ # load weight to c side
+ arr = data_.get_tensor(name_)
+ torch_dtype = self._torch_dtype(self._dtype)
+ if arr.dtype != torch_dtype:
+ arr = arr.to(torch_dtype)
+ arr = arr.contiguous()
+
+ LIB_LLAISYS.tensorLoad(weight, ctypes.c_void_p(arr.data_ptr()))
+
+
+ def _match_weight(self, name: str):
+ # match weight name to c struct field
+
+ w = self._weights
+ # input embedding
+ if name == "model.embed_tokens.weight":
+ return w.in_embed
+ # output embedding
+ if name in ("lm_head.weight", "model.lm_head.weight"):
+ return w.out_embed
+ # final LayerNorm
+ if name == "model.norm.weight":
+ return w.out_norm_w
+ # only process transformer layer weights
+ if not name.startswith("model.layers."):
+ return None
+ parts = name.split(".")
+ if len(parts) < 5:
+ return None
+ layer = int(parts[2]) # extract the layer index
+ tail = ".".join(parts[3:]) # remaining field suffix after the layer index
+ # Attention Layer
+ if tail == "input_layernorm.weight":
+ return w.attn_norm_w[layer]
+ if tail == "self_attn.q_proj.weight":
+ return w.attn_q_w[layer]
+ if tail == "self_attn.q_proj.bias":
+ return w.attn_q_b[layer]
+ if tail == "self_attn.k_proj.weight":
+ return w.attn_k_w[layer]
+ if tail == "self_attn.k_proj.bias":
+ return w.attn_k_b[layer]
+ if tail == "self_attn.v_proj.weight":
+ return w.attn_v_w[layer]
+ if tail == "self_attn.v_proj.bias":
+ return w.attn_v_b[layer]
+ if tail == "self_attn.o_proj.weight":
+ return w.attn_o_w[layer]
+ # FFN layer
+ if tail == "post_attention_layernorm.weight":
+ return w.mlp_norm_w[layer]
+ if tail == "mlp.gate_proj.weight":
+ return w.mlp_gate_w[layer]
+ if tail == "mlp.up_proj.weight":
+ return w.mlp_up_w[layer]
+ if tail == "mlp.down_proj.weight":
+ return w.mlp_down_w[layer]
+ return None
+
+ def _select_dtype(self, device: DeviceType) -> DataType:
+ dtype_env = os.environ.get("LLAISYS_DTYPE", "").strip().lower()
+ if dtype_env in ("f16", "float16"):
+ return DataType.F16
+ if dtype_env in ("f32", "float32"):
+ return DataType.F32
+ if dtype_env in ("bf16", "bfloat16"):
+ return DataType.BF16
+ if device == DeviceType.NVIDIA:
+ return DataType.F32
+ return DataType.BF16
+
+ def _torch_dtype(self, dtype: DataType):
+ if dtype == DataType.F16:
+ return torch.float16
+ if dtype == DataType.F32:
+ return torch.float32
+ if dtype == DataType.BF16:
+ return torch.bfloat16
+ return torch.float32
+
+ def _infer(self, tokens: Sequence[int], temperature: float, top_k: int, top_p: float, seed: int) -> int:
+ # step forward infer
+
+ if len(tokens) == 0:
+ return self._end_token
+ # convert python list to c int64 array
+ arr = (ctypes.c_int64 * len(tokens))(*tokens)
+ return int(
+ LIB_LLAISYS.llaisysQwen2ModelInfer(
+ self._model,
+ arr,
+ ctypes.c_size_t(len(tokens)),
+ ctypes.c_float(temperature),
+ ctypes.c_size_t(top_k),
+ ctypes.c_float(top_p),
+ ctypes.c_int64(seed),
+ )
+ )
def generate(
self,
@@ -26,8 +180,22 @@ def generate(
top_k: int = 1,
top_p: float = 0.8,
temperature: float = 0.8,
+ seed: int = 0,
):
-
- # TODO: Implement generate function
-
- return []
+ # max new tokens default value:32
+ if max_new_tokens is None:
+ max_new_tokens = 32
+ tokens = list(inputs)
+ if max_new_tokens == 0:
+ return tokens
+ # prefill
+ next_token = self._infer(tokens, temperature, top_k, top_p, seed)
+ tokens.append(next_token)
+ # decode
+ for _ in range(max_new_tokens - 1):
+ if tokens[-1] == self._end_token:
+ break
+ seed += 1
+ next_token = self._infer([tokens[-1]], temperature, top_k, top_p, seed)
+ tokens.append(next_token)
+ return tokens
diff --git a/python/llaisys/ops.py b/python/llaisys/ops.py
index ed0180bc..8c835eeb 100644
--- a/python/llaisys/ops.py
+++ b/python/llaisys/ops.py
@@ -1,6 +1,6 @@
from .libllaisys import LIB_LLAISYS
from .tensor import Tensor
-from ctypes import c_float, c_int
+from ctypes import c_float, c_int, c_int64, c_size_t
class Ops:
@@ -12,6 +12,26 @@ def add(c: Tensor, a: Tensor, b: Tensor):
def argmax(max_idx: Tensor, max_val: Tensor, vals: Tensor):
LIB_LLAISYS.llaisysArgmax(max_idx.lib_tensor(), max_val.lib_tensor(), vals.lib_tensor())
+ @staticmethod
+ def rand_sample(
+ sample_idx: Tensor,
+ sample_val: Tensor,
+ vals: Tensor,
+ temperature: float,
+ top_k: int,
+ top_p: float,
+ seed: int,
+ ):
+ LIB_LLAISYS.llaisysRandSample(
+ sample_idx.lib_tensor(),
+ sample_val.lib_tensor(),
+ vals.lib_tensor(),
+ c_float(temperature),
+ c_size_t(top_k),
+ c_float(top_p),
+ c_int64(seed),
+ )
+
@staticmethod
def embedding(out: Tensor, index: Tensor, weight: Tensor):
LIB_LLAISYS.llaisysEmbedding(
diff --git a/python/llaisys/server/chat_server.py b/python/llaisys/server/chat_server.py
new file mode 100644
index 00000000..72af35ff
--- /dev/null
+++ b/python/llaisys/server/chat_server.py
@@ -0,0 +1,264 @@
+import os
+import time
+import uuid
+import json
+import threading
+from pathlib import Path
+from typing import List, Optional, Iterable
+
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from starlette.responses import StreamingResponse, JSONResponse, FileResponse
+from starlette.staticfiles import StaticFiles
+from transformers import AutoTokenizer
+
+import llaisys
+
+# define Message format
+class ChatMessage(BaseModel):
+ role: str
+ content: str
+
+# define Request format
+class ChatRequest(BaseModel):
+ model: Optional[str] = None
+ messages: List[ChatMessage]
+ max_tokens: Optional[int] = 128
+ temperature: float = 1.0
+ top_p: float = 0.8
+ top_k: int = 50
+ stream: bool = False
+ seed: int = 0
+
+
+class AppState:
+ tokenizer = None
+ model = None
+ model_path = None
+ model_name = None
+ device = None
+
+
+state = AppState()
+app = FastAPI()
+model_lock = threading.Lock()
+
+WEBUI_INDEX = Path(__file__).resolve().parent.parent / "webUI" / "index.html"
+WEBUI_DIR = WEBUI_INDEX.parent.parent
+WEBUI_ENABLED = WEBUI_INDEX.is_file()
+if WEBUI_ENABLED:
+ app.mount("/webUI", StaticFiles(directory=WEBUI_DIR), name="webUI")
+
+# map device name to llaisys device type
+def llaisys_device(device_name: str):
+ if device_name == "cpu":
+ return llaisys.DeviceType.CPU
+ if device_name == "nvidia":
+ return llaisys.DeviceType.NVIDIA
+ raise ValueError(f"Unsupported device name: {device_name}")
+
+# set default values for model path, device, and model name
+# if not provided
+def ensure_state_loaded():
+ if state.model is not None and state.tokenizer is not None:
+ return
+ model_path = os.environ.get("LLAISYS_MODEL_PATH", "./models")
+ if not model_path:
+ raise RuntimeError("LLAISYS_MODEL_PATH is required")
+ device = os.environ.get("LLAISYS_DEVICE", "cpu")
+ model_name = os.environ.get("LLAISYS_MODEL_NAME", "llaisys-qwen2")
+ state.model_path = model_path
+ state.device = device
+ state.model_name = model_name
+ state.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ state.model = llaisys.models.Qwen2(model_path, llaisys_device(device))
+
+# build prompt from messages
+def build_prompt(messages: List[ChatMessage]) -> str:
+ conversation = [{"role": m.role, "content": m.content} for m in messages]
+ return state.tokenizer.apply_chat_template(
+ conversation=conversation,
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+
+# decode output token id and mask prompt tokens
+def decode_completion(input_ids: List[int], output_ids: List[int]) -> str:
+ prompt_text = state.tokenizer.decode(input_ids, skip_special_tokens=True)
+ full_text = state.tokenizer.decode(output_ids, skip_special_tokens=True)
+ return full_text[len(prompt_text):]
+
+# generate full completion
+def generate_full(
+ input_ids: List[int],
+ max_tokens: int,
+ temperature: float,
+ top_k: int,
+ top_p: float,
+ seed: int,
+):
+ output_ids = state.model.generate(
+ input_ids,
+ max_new_tokens=max_tokens,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ seed=seed,
+ )
+ completion_text = decode_completion(input_ids, output_ids)
+ completion_tokens = max(0, len(output_ids) - len(input_ids))
+ return completion_text, completion_tokens
+
+# stream mode
+def generate_stream(
+ input_ids: List[int],
+ max_tokens: int,
+ temperature: float,
+ top_k: int,
+ top_p: float,
+ seed: int,
+) -> Iterable[str]:
+ tokens = list(input_ids)
+ prompt_text = state.tokenizer.decode(tokens, skip_special_tokens=True)
+ prev_text = prompt_text
+ if max_tokens <= 0:
+ return
+ # generate first output token, then append it to prompt
+ next_token = state.model._infer(tokens, temperature, top_k, top_p, seed)
+ tokens.append(next_token)
+ text = state.tokenizer.decode(tokens, skip_special_tokens=True)
+ delta = text[len(prev_text):]
+ prev_text = text
+ # if we decode a non-empty delta, yield it, that's the first token
+ if delta:
+ yield delta
+ # continue generate subsequent tokens, until max_tokens or end_token
+ for _ in range(max_tokens - 1):
+ if tokens[-1] == state.model._end_token:
+ break
+ seed += 1
+ # because of kv-cache, we only need to infer the last token
+ next_token = state.model._infer([tokens[-1]], temperature, top_k, top_p, seed)
+ tokens.append(next_token)
+ text = state.tokenizer.decode(tokens, skip_special_tokens=True)
+ delta = text[len(prev_text):]
+ prev_text = text
+ if delta:
+ yield delta
+
+
+def sse_chunk(payload: dict) -> str:
+ return "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"
+
+
+@app.post("/v1/chat/completions")
+def chat_completions(req: ChatRequest):
+ ensure_state_loaded()
+ if not req.messages:
+ raise HTTPException(status_code=400, detail="messages is required")
+ prompt = build_prompt(req.messages)
+ input_ids = state.tokenizer.encode(prompt)
+ max_tokens = req.max_tokens if req.max_tokens is not None else 128
+ created = int(time.time())
+ request_id = "chatcmpl-" + uuid.uuid4().hex
+
+ if not req.stream:
+ with model_lock:
+ completion_text, completion_tokens = generate_full(
+ input_ids=input_ids,
+ max_tokens=max_tokens,
+ temperature=req.temperature,
+ top_k=req.top_k,
+ top_p=req.top_p,
+ seed=req.seed,
+ )
+ response = {
+ "id": request_id,
+ "object": "chat.completion",
+ "created": created,
+ "model": req.model or state.model_name,
+ "choices": [
+ {
+ "index": 0,
+ "message": {"role": "assistant", "content": completion_text},
+ "finish_reason": "stop",
+ }
+ ],
+ "usage": {
+ "prompt_tokens": len(input_ids),
+ "completion_tokens": completion_tokens,
+ "total_tokens": len(input_ids) + completion_tokens,
+ },
+ }
+ return JSONResponse(response)
+
+ def event_stream():
+ model_lock.acquire()
+ try:
+ yield sse_chunk(
+ {
+ "id": request_id,
+ "object": "chat.completion.chunk",
+ "created": created,
+ "model": req.model or state.model_name,
+ "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
+ }
+ )
+ for delta in generate_stream(
+ input_ids=input_ids,
+ max_tokens=max_tokens,
+ temperature=req.temperature,
+ top_k=req.top_k,
+ top_p=req.top_p,
+ seed=req.seed,
+ ):
+ yield sse_chunk(
+ {
+ "id": request_id,
+ "object": "chat.completion.chunk",
+ "created": created,
+ "model": req.model or state.model_name,
+ "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
+ }
+ )
+ yield sse_chunk(
+ {
+ "id": request_id,
+ "object": "chat.completion.chunk",
+ "created": created,
+ "model": req.model or state.model_name,
+ "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
+ }
+ )
+ yield "data: [DONE]\n\n"
+ finally:
+ model_lock.release()
+
+ return StreamingResponse(event_stream(), media_type="text/event-stream")
+
+
+@app.get("/")
+def chat_ui():
+ if WEBUI_ENABLED:
+ return FileResponse(WEBUI_INDEX)
+ raise HTTPException(status_code=404, detail="Web UI not found")
+
+
+if __name__ == "__main__":
+ import argparse
+ import uvicorn
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", default="0.0.0.0")
+ parser.add_argument("--port", default=8000, type=int)
+ parser.add_argument("--model_path", default=None)
+ parser.add_argument("--device", default="nvidia")
+ parser.add_argument("--model_name", default="llaisys-qwen2")
+ args = parser.parse_args()
+
+ if args.model_path:
+ os.environ["LLAISYS_MODEL_PATH"] = args.model_path
+ os.environ["LLAISYS_DEVICE"] = args.device
+ os.environ["LLAISYS_MODEL_NAME"] = args.model_name
+
+ uvicorn.run(app, host=args.host, port=args.port)
diff --git a/python/llaisys/webUI/index.html b/python/llaisys/webUI/index.html
new file mode 100644
index 00000000..6da35694
--- /dev/null
+++ b/python/llaisys/webUI/index.html
@@ -0,0 +1,363 @@
+
+
+
+
+
+ LLAISYS Chat UI
+
+
+
+
+
+
+
diff --git a/python/setup.cfg b/python/setup.cfg
index b35fc65f..88ff8f6e 100644
--- a/python/setup.cfg
+++ b/python/setup.cfg
@@ -13,6 +13,9 @@ install_requires =
torch>=2.4.0
transformers
accelerate
+ fastapi
+ uvicorn
+ sse-starlette
[options.package_data]
llaisys =
diff --git a/src/device/nvidia/nvidia_runtime_api.cu b/src/device/nvidia/nvidia_runtime_api.cu
index cab92826..adc39208 100644
--- a/src/device/nvidia/nvidia_runtime_api.cu
+++ b/src/device/nvidia/nvidia_runtime_api.cu
@@ -1,56 +1,97 @@
#include "../runtime_api.hpp"
+#include
#include
#include
+#include
+#include
namespace llaisys::device::nvidia {
namespace runtime_api {
+static void checkCuda(cudaError_t err) {
+ if (err != cudaSuccess) {
+ throw std::runtime_error(std::string("CUDA error: ") + cudaGetErrorString(err));
+ }
+}
+
+static cudaMemcpyKind toCudaMemcpyKind(llaisysMemcpyKind_t kind) {
+ switch (kind) {
+ case LLAISYS_MEMCPY_H2H:
+ return cudaMemcpyHostToHost;
+ case LLAISYS_MEMCPY_H2D:
+ return cudaMemcpyHostToDevice;
+ case LLAISYS_MEMCPY_D2H:
+ return cudaMemcpyDeviceToHost;
+ case LLAISYS_MEMCPY_D2D:
+ return cudaMemcpyDeviceToDevice;
+ default:
+ throw std::runtime_error("Unsupported memcpy kind");
+ }
+}
+
int getDeviceCount() {
- TO_BE_IMPLEMENTED();
+ int count = 0;
+ checkCuda(cudaGetDeviceCount(&count));
+ return count;
}
-void setDevice(int) {
- TO_BE_IMPLEMENTED();
+void setDevice(int device_id) {
+ checkCuda(cudaSetDevice(device_id));
}
void deviceSynchronize() {
- TO_BE_IMPLEMENTED();
+ checkCuda(cudaDeviceSynchronize());
}
llaisysStream_t createStream() {
- TO_BE_IMPLEMENTED();
+ cudaStream_t stream = nullptr;
+ checkCuda(cudaStreamCreate(&stream));
+ return reinterpret_cast(stream);
}
void destroyStream(llaisysStream_t stream) {
- TO_BE_IMPLEMENTED();
+ if (stream == nullptr) {
+ return;
+ }
+ checkCuda(cudaStreamDestroy(reinterpret_cast(stream)));
}
void streamSynchronize(llaisysStream_t stream) {
- TO_BE_IMPLEMENTED();
+ checkCuda(cudaStreamSynchronize(reinterpret_cast(stream)));
}
void *mallocDevice(size_t size) {
- TO_BE_IMPLEMENTED();
+ void *ptr = nullptr;
+ checkCuda(cudaMalloc(&ptr, size));
+ return ptr;
}
void freeDevice(void *ptr) {
- TO_BE_IMPLEMENTED();
+ if (ptr == nullptr) {
+ return;
+ }
+ checkCuda(cudaFree(ptr));
}
void *mallocHost(size_t size) {
- TO_BE_IMPLEMENTED();
+ void *ptr = nullptr;
+ checkCuda(cudaMallocHost(&ptr, size));
+ return ptr;
}
void freeHost(void *ptr) {
- TO_BE_IMPLEMENTED();
+ if (ptr == nullptr) {
+ return;
+ }
+ checkCuda(cudaFreeHost(ptr));
}
void memcpySync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) {
- TO_BE_IMPLEMENTED();
+ checkCuda(cudaMemcpy(dst, src, size, toCudaMemcpyKind(kind)));
}
-void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind) {
- TO_BE_IMPLEMENTED();
+void memcpyAsync(void *dst, const void *src, size_t size, llaisysMemcpyKind_t kind, llaisysStream_t stream) {
+ checkCuda(cudaMemcpyAsync(dst, src, size, toCudaMemcpyKind(kind), reinterpret_cast(stream)));
}
static const LlaisysRuntimeAPI RUNTIME_API = {
diff --git a/src/llaisys/ops.cc b/src/llaisys/ops.cc
index c99fbc32..50186759 100644
--- a/src/llaisys/ops.cc
+++ b/src/llaisys/ops.cc
@@ -6,6 +6,7 @@
#include "../ops/argmax/op.hpp"
#include "../ops/embedding/op.hpp"
#include "../ops/linear/op.hpp"
+#include "../ops/rand_sample/op.hpp"
#include "../ops/rearrange/op.hpp"
#include "../ops/rms_norm/op.hpp"
#include "../ops/rope/op.hpp"
@@ -19,6 +20,10 @@ __C {
void llaisysArgmax(llaisysTensor_t max_idx, llaisysTensor_t max_val, llaisysTensor_t vals) {
llaisys::ops::argmax(max_idx->tensor, max_val->tensor, vals->tensor);
}
+ void llaisysRandSample(llaisysTensor_t sample_idx, llaisysTensor_t sample_val, llaisysTensor_t vals,
+ float temperature, size_t topK, float topP, int64_t seed) {
+ llaisys::ops::rand_sample(sample_idx->tensor, sample_val->tensor, vals->tensor, temperature, topK, topP, seed);
+ }
void llaisysEmbedding(llaisysTensor_t out, llaisysTensor_t index, llaisysTensor_t weight) {
llaisys::ops::embedding(out->tensor, index->tensor, weight->tensor);
}
diff --git a/src/llaisys/qwen2.cc b/src/llaisys/qwen2.cc
new file mode 100644
index 00000000..286ec4ea
--- /dev/null
+++ b/src/llaisys/qwen2.cc
@@ -0,0 +1,832 @@
+#include "llaisys/models/qwen2.h"
+#include "llaisys_tensor.hpp"
+
+#include "../tensor/tensor.hpp"
+#include "../core/llaisys_core.hpp"
+#include "../ops/add/op.hpp"
+#include "../ops/rand_sample/op.hpp"
+#include "../ops/embedding/op.hpp"
+#include "../ops/linear/op.hpp"
+#include "../ops/rearrange/op.hpp"
+#include "../ops/rms_norm/op.hpp"
+#include "../ops/rope/op.hpp"
+#include "../ops/self_attention/op.hpp"
+#include "../ops/swiglu/op.hpp"
+#include "../utils.hpp"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+// wrap c++ tensor to external handle
+namespace {
+
+// based on modelMetaInfo create tensor
+llaisys::tensor_t make_tensor(
+ const LlaisysQwen2Meta &meta,
+ llaisysDeviceType_t device,
+ int device_id,
+ const std::vector &shape) {
+ return llaisys::Tensor::create(shape, meta.dtype, device, device_id);
+}
+
+// based on dtype create tensor
+llaisys::tensor_t make_tensor_dtype(
+ llaisysDataType_t dtype,
+ llaisysDeviceType_t device,
+ int device_id,
+ const std::vector &shape) {
+ return llaisys::Tensor::create(shape, dtype, device, device_id);
+}
+
+// set tensor data to zero
+void zero_tensor(const llaisys::tensor_t &t) {
+ size_t size = t->numel() * t->elementSize();
+ if (t->deviceType() == LLAISYS_DEVICE_CPU) {
+ std::memset(t->data(), 0, size);
+ return;
+ }
+ std::vector zeros(size);
+ t->load(zeros.data());
+}
+
+// wrap c++ tensor to external handle and record to handles list
+llaisysTensor_t wrap_tensor(
+ const llaisys::tensor_t &t,
+ std::vector &handles) {
+ auto *h = new LlaisysTensor{t};
+ handles.push_back(h);
+ return h;
+}
+
+size_t ceil_div(size_t a, size_t b) {
+ return (a + b - 1) / b;
+}
+
+size_t read_env_size_t(const char *name, size_t default_value) {
+#ifdef _WIN32
+ char *value = nullptr;
+ size_t len = 0;
+ if (_dupenv_s(&value, &len, name) != 0 || value == nullptr) {
+ return default_value;
+ }
+ char *end = nullptr;
+ unsigned long long parsed = std::strtoull(value, &end, 10);
+ std::free(value);
+ if (end == value || *end != '\0') {
+ return default_value;
+ }
+ return static_cast(parsed);
+#else
+ const char *value = std::getenv(name);
+ if (!value) {
+ return default_value;
+ }
+ char *end = nullptr;
+ unsigned long long parsed = std::strtoull(value, &end, 10);
+ if (end == value || *end != '\0') {
+ return default_value;
+ }
+ return static_cast(parsed);
+#endif
+}
+
+uint64_t hash_mix(uint64_t h, int64_t v, uint64_t base) {
+ return h * base + (static_cast(v) + 0x9e3779b97f4a7c15ull);
+}
+
+struct KVPagePoolLayer {
+ std::vector k_pages;
+ std::vector v_pages;
+ std::vector refcnt;
+ std::vector last_access;
+ std::vector free_ids;
+ std::vector is_free;
+};
+
+class KVCachePool {
+public:
+ void init(const LlaisysQwen2Meta &meta, llaisysDeviceType_t device, int device_id, size_t page_len, size_t max_pages) {
+ meta_ = meta;
+ device_ = device;
+ device_id_ = device_id;
+ page_len_ = page_len;
+ max_pages_ = max_pages;
+ access_clock_ = 1;
+ layers_.assign(meta.nlayer, KVPagePoolLayer{});
+ }
+
+ size_t page_len() const {
+ return page_len_;
+ }
+
+ size_t acquire_page(size_t layer) {
+ auto &pool = layers_[layer];
+ while (!pool.free_ids.empty()) {
+ size_t id = pool.free_ids.back();
+ pool.free_ids.pop_back();
+ if (pool.is_free[id] && pool.refcnt[id] == 0) {
+ pool.is_free[id] = 0;
+ pool.last_access[id] = access_clock_++;
+ return id;
+ }
+ }
+
+ if (pool.k_pages.size() < max_pages_) {
+ size_t id = pool.k_pages.size();
+ pool.k_pages.push_back(make_tensor(meta_, device_, device_id_, {page_len_, meta_.nkvh, meta_.dh}));
+ pool.v_pages.push_back(make_tensor(meta_, device_, device_id_, {page_len_, meta_.nkvh, meta_.dh}));
+ pool.refcnt.push_back(0);
+ pool.last_access.push_back(access_clock_++);
+ pool.is_free.push_back(0);
+ return id;
+ }
+
+ size_t selected = std::numeric_limits::max();
+ uint64_t best_access = std::numeric_limits::max();
+ for (size_t i = 0; i < pool.k_pages.size(); ++i) {
+ if (pool.refcnt[i] == 0 && pool.last_access[i] <= best_access) {
+ best_access = pool.last_access[i];
+ selected = i;
+ }
+ }
+ CHECK_ARGUMENT(selected != std::numeric_limits::max(), "KV cache pool has no free page.");
+ pool.is_free[selected] = 0;
+ pool.last_access[selected] = access_clock_++;
+ return selected;
+ }
+
+ void incref(size_t layer, size_t page_id) {
+ auto &pool = layers_[layer];
+ if (pool.is_free[page_id]) {
+ pool.is_free[page_id] = 0;
+ }
+ pool.refcnt[page_id] += 1;
+ pool.last_access[page_id] = access_clock_++;
+ }
+
+ void decref(size_t layer, size_t page_id) {
+ auto &pool = layers_[layer];
+ if (pool.refcnt[page_id] > 0) {
+ pool.refcnt[page_id] -= 1;
+ if (pool.refcnt[page_id] == 0 && !pool.is_free[page_id]) {
+ pool.is_free[page_id] = 1;
+ pool.free_ids.push_back(page_id);
+ }
+ }
+ pool.last_access[page_id] = access_clock_++;
+ }
+
+ llaisys::tensor_t k_page(size_t layer, size_t page_id) {
+ auto &pool = layers_[layer];
+ pool.last_access[page_id] = access_clock_++;
+ return pool.k_pages[page_id];
+ }
+
+ llaisys::tensor_t v_page(size_t layer, size_t page_id) {
+ auto &pool = layers_[layer];
+ pool.last_access[page_id] = access_clock_++;
+ return pool.v_pages[page_id];
+ }
+
+private:
+ LlaisysQwen2Meta meta_;
+ llaisysDeviceType_t device_;
+ int device_id_;
+ size_t page_len_;
+ size_t max_pages_;
+ uint64_t access_clock_ = 1;
+ std::vector layers_;
+};
+
+struct KVHandle {
+ std::vector> layer_pages;
+ size_t token_count = 0;
+ uint64_t last_access = 0;
+ std::vector tokens;
+ std::vector hash_keys;
+};
+
+// Maps token-prefix hashes to cached KV handles so later requests can reuse
+// KV pages of a shared prompt prefix. Eviction is LRU bounded by max_handles_.
+// NOTE(review): angle-bracket template arguments were stripped from this
+// patch text (e.g. "std::vector handles_;", "static_cast(len)",
+// "std::numeric_limits::max()"); the patch will not compile as-is — restore
+// the original template arguments before applying.
+class PrefixCacheIndex {
+public:
+ // Reset the index to an empty state with a new capacity bound.
+ void init(size_t max_handles) {
+ max_handles_ = max_handles;
+ access_clock_ = 1;
+ handles_.clear();
+ alive_.clear();
+ key_to_handles_.clear();
+ free_handle_ids_.clear();
+ }
+
+ // Find the longest cached prefix of tokens[0..ntoken). On a hit, fills
+ // out_handle with the page ids covering that prefix and returns its length;
+ // returns 0 on a miss. Page refcounts are NOT incremented here — the caller
+ // is responsible for incref'ing the pages it keeps.
+ size_t find_longest_prefix(const int64_t *tokens, size_t ntoken, const KVCachePool &pool, KVHandle &out_handle) {
+ if (ntoken == 0) {
+ return 0;
+ }
+ auto hashes = prefix_hashes(tokens, ntoken);
+ // Try the longest prefix first and shrink until a verified hit.
+ for (size_t len = ntoken; len > 0; --len) {
+ uint64_t key = make_key(hashes[len], len);
+ auto it = key_to_handles_.find(key);
+ if (it == key_to_handles_.end()) {
+ continue;
+ }
+ for (size_t handle_id : it->second) {
+ if (handle_id >= handles_.size() || !alive_[handle_id]) {
+ continue;
+ }
+ const auto &h = handles_[handle_id];
+ if (h.tokens.size() < len) {
+ continue;
+ }
+ // Hashes can collide; confirm with an exact token comparison.
+ if (!std::equal(tokens, tokens + len, h.tokens.begin())) {
+ continue;
+ }
+ size_t pages_needed = ceil_div(len, pool.page_len());
+ out_handle.layer_pages.assign(h.layer_pages.size(), {});
+ for (size_t i = 0; i < h.layer_pages.size(); ++i) {
+ out_handle.layer_pages[i].assign(h.layer_pages[i].begin(), h.layer_pages[i].begin() + pages_needed);
+ }
+ out_handle.token_count = len;
+ out_handle.tokens.assign(tokens, tokens + len);
+ out_handle.last_access = access_clock_++;
+ handles_[handle_id].last_access = out_handle.last_access;
+ return len;
+ }
+ }
+ return 0;
+ }
+
+ // Store a finished handle under every prefix-length key so future lookups
+ // can match any prefix of it. Takes its own refcount on each stored page.
+ void insert_handle(const KVHandle &handle, const int64_t *tokens, size_t ntoken, KVCachePool &pool) {
+ if (ntoken == 0) {
+ return;
+ }
+ size_t handle_id = acquire_handle_id();
+ if (handle_id >= handles_.size()) {
+ handles_.resize(handle_id + 1);
+ alive_.resize(handle_id + 1, 0);
+ }
+ // Defensive: recycled ids should already be dead (release_handle marks
+ // them before recycling), but release again if not.
+ if (alive_[handle_id]) {
+ release_handle(handle_id, pool);
+ }
+ KVHandle stored;
+ stored.token_count = ntoken;
+ stored.tokens.assign(tokens, tokens + ntoken);
+ stored.layer_pages = handle.layer_pages;
+ size_t pages_needed = ceil_div(ntoken, pool.page_len());
+ for (size_t i = 0; i < stored.layer_pages.size(); ++i) {
+ if (stored.layer_pages[i].size() > pages_needed) {
+ stored.layer_pages[i].resize(pages_needed);
+ }
+ for (size_t page_id : stored.layer_pages[i]) {
+ pool.incref(i, page_id);
+ }
+ }
+ auto hashes = prefix_hashes(tokens, ntoken);
+ stored.hash_keys.reserve(ntoken);
+ // Register one key per prefix length (1..ntoken).
+ for (size_t len = 1; len <= ntoken; ++len) {
+ uint64_t key = make_key(hashes[len], len);
+ key_to_handles_[key].push_back(handle_id);
+ stored.hash_keys.push_back(key);
+ }
+ stored.last_access = access_clock_++;
+ handles_[handle_id] = std::move(stored);
+ alive_[handle_id] = 1;
+ enforce_capacity(pool);
+ }
+
+ // Drop a handle: release its page refcounts, unlink its hash keys, and
+ // recycle the slot id.
+ void release_handle(size_t handle_id, KVCachePool &pool) {
+ if (handle_id >= handles_.size() || !alive_[handle_id]) {
+ return;
+ }
+ auto &h = handles_[handle_id];
+ for (size_t i = 0; i < h.layer_pages.size(); ++i) {
+ for (size_t page_id : h.layer_pages[i]) {
+ pool.decref(i, page_id);
+ }
+ }
+ for (uint64_t key : h.hash_keys) {
+ auto it = key_to_handles_.find(key);
+ if (it == key_to_handles_.end()) {
+ continue;
+ }
+ auto &vec = it->second;
+ vec.erase(std::remove(vec.begin(), vec.end(), handle_id), vec.end());
+ if (vec.empty()) {
+ key_to_handles_.erase(it);
+ }
+ }
+ h.layer_pages.clear();
+ h.tokens.clear();
+ h.hash_keys.clear();
+ alive_[handle_id] = 0;
+ free_handle_ids_.push_back(handle_id);
+ }
+
+private:
+ // Mix the prefix length into the rolling hash so distinct lengths with the
+ // same hash get distinct keys.
+ uint64_t make_key(uint64_t hash, size_t len) const {
+ return hash ^ (salt_ * static_cast(len));
+ }
+
+ // hashes[k] is the rolling hash of the first k tokens (hashes[0] == 0).
+ std::vector prefix_hashes(const int64_t *tokens, size_t ntoken) const {
+ std::vector hashes(ntoken + 1);
+ uint64_t h = 0;
+ for (size_t i = 0; i < ntoken; ++i) {
+ h = hash_mix(h, tokens[i], base_);
+ hashes[i + 1] = h;
+ }
+ return hashes;
+ }
+
+ // Reuse a freed slot if available, otherwise grow by one.
+ size_t acquire_handle_id() {
+ if (!free_handle_ids_.empty()) {
+ size_t id = free_handle_ids_.back();
+ free_handle_ids_.pop_back();
+ return id;
+ }
+ return handles_.size();
+ }
+
+ // Evict least-recently-used handles until at most max_handles_ remain.
+ // Linear scan per eviction — acceptable for the small default capacity.
+ void enforce_capacity(KVCachePool &pool) {
+ if (max_handles_ == 0) {
+ return;
+ }
+ size_t alive_count = 0;
+ for (uint8_t v : alive_) {
+ alive_count += v;
+ }
+ while (alive_count > max_handles_) {
+ size_t oldest = std::numeric_limits::max();
+ uint64_t oldest_access = std::numeric_limits::max();
+ for (size_t i = 0; i < handles_.size(); ++i) {
+ if (!alive_[i]) {
+ continue;
+ }
+ if (handles_[i].last_access <= oldest_access) {
+ oldest_access = handles_[i].last_access;
+ oldest = i;
+ }
+ }
+ if (oldest == std::numeric_limits::max()) {
+ break;
+ }
+ release_handle(oldest, pool);
+ alive_count -= 1;
+ }
+ }
+
+ // FNV-1a-style constants for the rolling hash / key salt.
+ uint64_t base_ = 1469598103934665603ull;
+ uint64_t salt_ = 1099511628211ull;
+ size_t max_handles_ = 64;
+ uint64_t access_clock_ = 1;
+ std::unordered_map> key_to_handles_;
+ std::vector handles_;
+ std::vector alive_;
+ std::vector free_handle_ids_;
+};
+
+}
+
+// Qwen2 Model Instance Structure
+// Qwen2 Model Instance Structure
+// NOTE(review): vector element types were stripped from this patch text
+// (e.g. "std::vector handles;"); restore from the original commit.
+struct LlaisysQwen2Model {
+    LlaisysQwen2Meta meta;       // model meta info
+    llaisysDeviceType_t device;  // device type
+    int device_id;               // device id
+
+    LlaisysQwen2Weights weights; // all weight tensors handle
+    std::vector handles;         // handles list for unified release
+
+    // Internal zero-filled bias tensors for projections whose exported
+    // weights struct has no bias slot (presumably bias-free in the
+    // checkpoint — confirm); keeps the linear op call uniform.
+    // attn out bias
+    std::vector attn_o_bias;
+    // MLP gate bias
+    std::vector mlp_gate_bias;
+    // MLP up bias
+    std::vector mlp_up_bias;
+    // MLP down bias
+    std::vector mlp_down_bias;
+    // out bias
+    llaisys::tensor_t out_bias;
+
+    // KV cache state: paged pool + prefix-reuse index. active_handle holds
+    // page refcounts for the sequence currently being decoded when
+    // active_valid is true; cache_len is the number of tokens cached so far.
+    // KV cache
+    KVCachePool kv_pool;
+    PrefixCacheIndex prefix_index;
+    KVHandle active_handle;
+    bool active_valid;
+    size_t cache_len;
+};
+
+// Drop the model's current decoding handle: release its page refcounts and
+// clear its state. NOTE(review): model->cache_len is intentionally NOT reset
+// here — callers reset it after re-establishing a prefix; confirm all call
+// sites do so.
+static void release_active_handle(LlaisysQwen2Model *model) {
+ if (!model || !model->active_valid) {
+ return;
+ }
+ for (size_t i = 0; i < model->active_handle.layer_pages.size(); ++i) {
+ for (size_t page_id : model->active_handle.layer_pages[i]) {
+ model->kv_pool.decref(i, page_id);
+ }
+ }
+ model->active_handle.layer_pages.assign(model->meta.nlayer, {});
+ model->active_handle.tokens.clear();
+ model->active_handle.token_count = 0;
+ model->active_valid = false;
+}
+
+// init model weights, allocate memory and zero bias
+static void init_weights(LlaisysQwen2Model *model) {
+ const auto &m = model->meta;
+ // input embedding
+ model->weights.in_embed = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.voc, m.hs}),
+ model->handles);
+ // output embedding
+ model->weights.out_embed = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.voc, m.hs}),
+ model->handles);
+ // output layer norm
+ model->weights.out_norm_w = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.hs}),
+ model->handles);
+
+ // allocate ptr array for each layer
+ model->weights.attn_norm_w = new llaisysTensor_t[m.nlayer];
+ model->weights.attn_q_w = new llaisysTensor_t[m.nlayer];
+ model->weights.attn_q_b = new llaisysTensor_t[m.nlayer];
+ model->weights.attn_k_w = new llaisysTensor_t[m.nlayer];
+ model->weights.attn_k_b = new llaisysTensor_t[m.nlayer];
+ model->weights.attn_v_w = new llaisysTensor_t[m.nlayer];
+ model->weights.attn_v_b = new llaisysTensor_t[m.nlayer];
+ model->weights.attn_o_w = new llaisysTensor_t[m.nlayer];
+ model->weights.mlp_norm_w = new llaisysTensor_t[m.nlayer];
+ model->weights.mlp_gate_w = new llaisysTensor_t[m.nlayer];
+ model->weights.mlp_up_w = new llaisysTensor_t[m.nlayer];
+ model->weights.mlp_down_w = new llaisysTensor_t[m.nlayer];
+
+ // bias
+ model->attn_o_bias.resize(m.nlayer);
+ model->mlp_gate_bias.resize(m.nlayer);
+ model->mlp_up_bias.resize(m.nlayer);
+ model->mlp_down_bias.resize(m.nlayer);
+
+ // for each layer, create its weight and bias tensor
+ // first weight tensor
+ for (size_t i = 0; i < m.nlayer; ++i) {
+ // attn layer norm
+ model->weights.attn_norm_w[i] = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.hs}),
+ model->handles);
+ // Q/K/V/O proj weight and bias
+ model->weights.attn_q_w[i] = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.nh * m.dh, m.hs}),
+ model->handles);
+ model->weights.attn_q_b[i] = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.nh * m.dh}),
+ model->handles);
+ model->weights.attn_k_w[i] = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.nkvh * m.dh, m.hs}),
+ model->handles);
+ model->weights.attn_k_b[i] = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.nkvh * m.dh}),
+ model->handles);
+ model->weights.attn_v_w[i] = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.nkvh * m.dh, m.hs}),
+ model->handles);
+ model->weights.attn_v_b[i] = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.nkvh * m.dh}),
+ model->handles);
+ model->weights.attn_o_w[i] = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.hs, m.nh * m.dh}),
+ model->handles);
+
+ // MLP layer norm and gate/up/down weight
+ model->weights.mlp_norm_w[i] = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.hs}),
+ model->handles);
+ model->weights.mlp_gate_w[i] = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.di, m.hs}),
+ model->handles);
+ model->weights.mlp_up_w[i] = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.di, m.hs}),
+ model->handles);
+ model->weights.mlp_down_w[i] = wrap_tensor(
+ make_tensor(m, model->device, model->device_id, {m.hs, m.di}),
+ model->handles);
+
+ // bias tensor
+ model->attn_o_bias[i] = make_tensor(m, model->device, model->device_id, {m.hs});
+ model->mlp_gate_bias[i] = make_tensor(m, model->device, model->device_id, {m.di});
+ model->mlp_up_bias[i] = make_tensor(m, model->device, model->device_id, {m.di});
+ model->mlp_down_bias[i] = make_tensor(m, model->device, model->device_id, {m.hs});
+
+ // initialize bias tensor to 0
+ zero_tensor(model->weights.attn_q_b[i]->tensor);
+ zero_tensor(model->weights.attn_k_b[i]->tensor);
+ zero_tensor(model->weights.attn_v_b[i]->tensor);
+ zero_tensor(model->attn_o_bias[i]);
+ zero_tensor(model->mlp_gate_bias[i]);
+ zero_tensor(model->mlp_up_bias[i]);
+ zero_tensor(model->mlp_down_bias[i]);
+ }
+
+ // output layerbias
+ model->out_bias = make_tensor(m, model->device, model->device_id, {m.voc});
+ zero_tensor(model->out_bias);
+}
+
+// initialize kv cache tensor
+// Sets up the paged KV pool and prefix index. Page length, max pages, and
+// max prefix handles are tunable via environment variables; defaults are a
+// 128-token page and enough pages for meta.maxseq tokens.
+static void init_cache(LlaisysQwen2Model *model) {
+ const auto &m = model->meta;
+ size_t page_len = read_env_size_t("LLAISYS_KV_PAGE_LEN", 128);
+ size_t max_pages = read_env_size_t("LLAISYS_KV_MAX_PAGES", ceil_div(m.maxseq, page_len));
+ size_t max_handles = read_env_size_t("LLAISYS_KV_MAX_HANDLES", 64);
+ model->kv_pool.init(m, model->device, model->device_id, page_len, max_pages);
+ model->prefix_index.init(max_handles);
+ model->active_handle.layer_pages.assign(m.nlayer, {});
+ model->active_handle.tokens.clear();
+ model->active_handle.token_count = 0;
+ model->active_valid = false;
+ model->cache_len = 0;
+}
+
+// inference implementation
+static int64_t infer_impl(LlaisysQwen2Model *model, const int64_t *token_ids, size_t ntoken, float temperature,
+ size_t topK, float topP, int64_t seed) {
+ CHECK_ARGUMENT(model != nullptr, "model is null");
+ CHECK_ARGUMENT(token_ids != nullptr || ntoken == 0, "token_ids is null");
+ CHECK_ARGUMENT(model->device == LLAISYS_DEVICE_CPU || model->device == LLAISYS_DEVICE_NVIDIA,
+ "Unsupported device type.");
+
+ if (ntoken == 0) {
+ return model->meta.end_token;
+ }
+
+ size_t reuse_len = 0;
+ if (ntoken > 1) {
+ release_active_handle(model);
+ reuse_len = model->prefix_index.find_longest_prefix(token_ids, ntoken, model->kv_pool, model->active_handle);
+ if (reuse_len > 0) {
+ model->active_valid = true;
+ for (size_t i = 0; i < model->active_handle.layer_pages.size(); ++i) {
+ for (size_t page_id : model->active_handle.layer_pages[i]) {
+ model->kv_pool.incref(i, page_id);
+ }
+ }
+ }
+ if (reuse_len >= ntoken) {
+ reuse_len = ntoken - 1;
+ }
+ if (model->active_valid) {
+ size_t max_pages = ceil_div(reuse_len, model->kv_pool.page_len());
+ for (size_t i = 0; i < model->active_handle.layer_pages.size(); ++i) {
+ while (model->active_handle.layer_pages[i].size() > max_pages) {
+ size_t page_id = model->active_handle.layer_pages[i].back();
+ model->active_handle.layer_pages[i].pop_back();
+ model->kv_pool.decref(i, page_id);
+ }
+ }
+ model->active_handle.token_count = reuse_len;
+ model->active_handle.tokens.assign(token_ids, token_ids + reuse_len);
+ }
+ model->cache_len = reuse_len;
+ }
+
+ size_t seqlen = ntoken - reuse_len;
+ size_t pos_offset = model->cache_len;
+
+ // position ID [pos_offset, pos_offset + seqlen)
+ std::vector pos_ids(seqlen);
+ for (size_t i = 0; i < seqlen; ++i) {
+ pos_ids[i] = static_cast(pos_offset + i);
+ }
+
+ // input token and position ID tensors
+ auto input_ids = make_tensor_dtype(LLAISYS_DTYPE_I64, model->device, model->device_id, {seqlen});
+ input_ids->load(token_ids + reuse_len);
+
+ auto pos_tensor = make_tensor_dtype(LLAISYS_DTYPE_I64, model->device, model->device_id, {seqlen});
+ pos_tensor->load(pos_ids.data());
+
+ // embedding lookup
+ auto x = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.hs});
+ llaisys::ops::embedding(x, input_ids, model->weights.in_embed->tensor);
+
+ // attn scale factor
+ float scale = 1.0f / std::sqrt(static_cast(model->meta.dh));
+
+ size_t total_len = model->cache_len + seqlen;
+ size_t page_len = model->kv_pool.page_len();
+ size_t pages_needed = ceil_div(total_len, page_len);
+ if (model->active_handle.layer_pages.size() != model->meta.nlayer) {
+ model->active_handle.layer_pages.assign(model->meta.nlayer, {});
+ }
+ for (size_t i = 0; i < model->meta.nlayer; ++i) {
+ while (model->active_handle.layer_pages[i].size() < pages_needed) {
+ size_t page_id = model->kv_pool.acquire_page(i);
+ model->kv_pool.incref(i, page_id);
+ model->active_handle.layer_pages[i].push_back(page_id);
+ model->active_valid = true;
+ }
+ }
+
+ // layer forward
+ for (size_t i = 0; i < model->meta.nlayer; ++i) {
+ // attn input norm
+ auto x_norm = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.hs});
+ llaisys::ops::rms_norm(x_norm, x, model->weights.attn_norm_w[i]->tensor, model->meta.epsilon);
+
+ // Q/K/V linear proj
+ auto q = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.nh * model->meta.dh});
+ auto k = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.nkvh * model->meta.dh});
+ auto v = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.nkvh * model->meta.dh});
+
+ llaisys::ops::linear(q, x_norm, model->weights.attn_q_w[i]->tensor, model->weights.attn_q_b[i]->tensor);
+ llaisys::ops::linear(k, x_norm, model->weights.attn_k_w[i]->tensor, model->weights.attn_k_b[i]->tensor);
+ llaisys::ops::linear(v, x_norm, model->weights.attn_v_w[i]->tensor, model->weights.attn_v_b[i]->tensor);
+
+ // transform to multi-head dim
+ auto q_view = q->view({seqlen, model->meta.nh, model->meta.dh});
+ auto k_view = k->view({seqlen, model->meta.nkvh, model->meta.dh});
+ auto v_view = v->view({seqlen, model->meta.nkvh, model->meta.dh});
+
+ // RoPE
+ auto q_rope = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.nh, model->meta.dh});
+ auto k_rope = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.nkvh, model->meta.dh});
+
+ llaisys::ops::rope(q_rope, q_view, pos_tensor, model->meta.theta);
+ llaisys::ops::rope(k_rope, k_view, pos_tensor, model->meta.theta);
+
+ size_t write_offset = model->cache_len;
+ size_t remaining = seqlen;
+ size_t src_offset = 0;
+ while (remaining > 0) {
+ size_t page_index = write_offset / page_len;
+ size_t page_offset = write_offset % page_len;
+ size_t chunk = std::min(remaining, page_len - page_offset);
+ size_t page_id = model->active_handle.layer_pages[i][page_index];
+ auto k_page = model->kv_pool.k_page(i, page_id);
+ auto v_page = model->kv_pool.v_page(i, page_id);
+ auto k_page_slice = k_page->slice(0, page_offset, page_offset + chunk);
+ auto v_page_slice = v_page->slice(0, page_offset, page_offset + chunk);
+ auto k_chunk = k_rope->slice(0, src_offset, src_offset + chunk);
+ auto v_chunk = v_view->slice(0, src_offset, src_offset + chunk);
+ llaisys::ops::rearrange(k_page_slice, k_chunk);
+ llaisys::ops::rearrange(v_page_slice, v_chunk);
+ write_offset += chunk;
+ src_offset += chunk;
+ remaining -= chunk;
+ }
+
+ auto k_total = make_tensor(model->meta, model->device, model->device_id, {total_len, model->meta.nkvh, model->meta.dh});
+ auto v_total = make_tensor(model->meta, model->device, model->device_id, {total_len, model->meta.nkvh, model->meta.dh});
+ size_t read_offset = 0;
+ for (size_t page_index = 0; page_index < pages_needed; ++page_index) {
+ size_t chunk = std::min(page_len, total_len - read_offset);
+ size_t page_id = model->active_handle.layer_pages[i][page_index];
+ auto k_page = model->kv_pool.k_page(i, page_id);
+ auto v_page = model->kv_pool.v_page(i, page_id);
+ auto k_page_slice = k_page->slice(0, 0, chunk);
+ auto v_page_slice = v_page->slice(0, 0, chunk);
+ auto k_total_slice = k_total->slice(0, read_offset, read_offset + chunk);
+ auto v_total_slice = v_total->slice(0, read_offset, read_offset + chunk);
+ llaisys::ops::rearrange(k_total_slice, k_page_slice);
+ llaisys::ops::rearrange(v_total_slice, v_page_slice);
+ read_offset += chunk;
+ }
+
+ // self attn
+ auto attn = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.nh, model->meta.dh});
+ llaisys::ops::self_attention(attn, q_rope, k_total, v_total, scale);
+
+ // attn out proj
+ auto attn_flat = attn->view({seqlen, model->meta.nh * model->meta.dh});
+ auto attn_proj = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.hs});
+ llaisys::ops::linear(attn_proj, attn_flat, model->weights.attn_o_w[i]->tensor, model->attn_o_bias[i]);
+
+ // first residual conn
+ auto res1 = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.hs});
+ llaisys::ops::add(res1, x, attn_proj);
+
+ // MLP input norm
+ auto x_norm2 = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.hs});
+ llaisys::ops::rms_norm(x_norm2, res1, model->weights.mlp_norm_w[i]->tensor, model->meta.epsilon);
+
+ // MLP gate and up proj
+ auto gate = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.di});
+ auto up = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.di});
+
+ llaisys::ops::linear(gate, x_norm2, model->weights.mlp_gate_w[i]->tensor, model->mlp_gate_bias[i]);
+ llaisys::ops::linear(up, x_norm2, model->weights.mlp_up_w[i]->tensor, model->mlp_up_bias[i]);
+
+ // SwiGLU activate
+ auto swiglu_out = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.di});
+ llaisys::ops::swiglu(swiglu_out, gate, up);
+
+ // MLP down proj
+ auto down = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.hs});
+ llaisys::ops::linear(down, swiglu_out, model->weights.mlp_down_w[i]->tensor, model->mlp_down_bias[i]);
+
+ // second residual conn
+ auto res2 = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.hs});
+ llaisys::ops::add(res2, res1, down);
+
+ x = res2;
+ }
+
+ // update cache len
+ model->cache_len += seqlen;
+ model->active_handle.token_count = model->cache_len;
+ if (model->active_valid) {
+ if (model->active_handle.tokens.size() < model->cache_len) {
+ model->active_handle.tokens.insert(
+ model->active_handle.tokens.end(),
+ token_ids + reuse_len,
+ token_ids + reuse_len + seqlen);
+ }
+ }
+ if (ntoken > 1 && model->active_valid) {
+ model->prefix_index.insert_handle(model->active_handle, token_ids, ntoken, model->kv_pool);
+ }
+
+ // final norm
+ auto x_norm = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.hs});
+ llaisys::ops::rms_norm(x_norm, x, model->weights.out_norm_w->tensor, model->meta.epsilon);
+
+ // output proj to vocab
+ auto logits = make_tensor(model->meta, model->device, model->device_id, {seqlen, model->meta.voc});
+ llaisys::ops::linear(logits, x_norm, model->weights.out_embed->tensor, model->out_bias);
+
+ // get last token logits and sample
+ auto last = logits->slice(0, seqlen - 1, seqlen)->view({model->meta.voc});
+ auto sample_idx = make_tensor_dtype(LLAISYS_DTYPE_I64, model->device, model->device_id, {1});
+ auto sample_val = make_tensor(model->meta, model->device, model->device_id, {1});
+ llaisys::ops::rand_sample(sample_idx, sample_val, last, temperature, topK, topP, seed);
+
+ if (model->device == LLAISYS_DEVICE_CPU) {
+ return reinterpret_cast(sample_idx->data())[0];
+ }
+ int64_t host_value = 0;
+ llaisys::core::context().setDevice(model->device, model->device_id);
+ llaisys::core::context().runtime().api()->memcpy_sync(
+ &host_value,
+ sample_idx->data(),
+ sizeof(int64_t),
+ LLAISYS_MEMCPY_D2H);
+ return host_value;
+}
+
+// C API wrapper
+__C {
+
+// ModelCreate: Check args, initialize weights and cache
+// ModelCreate: Check args, initialize weights and cache
+// NOTE(review): meta is dereferenced without a null check, and only
+// device_ids[0] is honored even when ndevice > 1 (single-device build) —
+// confirm this is intended.
+struct LlaisysQwen2Model *llaisysQwen2ModelCreate(
+ const LlaisysQwen2Meta *meta,
+ llaisysDeviceType_t device,
+ int *device_ids,
+ int ndevice) {
+
+ auto *model = new LlaisysQwen2Model();
+ model->meta = *meta;
+ model->device = device;
+ model->device_id = device_ids ? device_ids[0] : 0;
+ init_weights(model);
+ init_cache(model);
+ return model;
+}
+
+// ModelDestroy: Free weights and cache, all handles
+// ModelDestroy: Free weights and cache, all handles
+// Frees the per-layer pointer arrays from init_weights, then the wrapped
+// tensor handles, then the model itself (bias tensors and the KV pool are
+// members released by ~LlaisysQwen2Model — presumably RAII tensor_t handles;
+// confirm).
+void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model *model) {
+ if (!model) {
+ return;
+ }
+ delete[] model->weights.attn_norm_w;
+ delete[] model->weights.attn_q_w;
+ delete[] model->weights.attn_q_b;
+ delete[] model->weights.attn_k_w;
+ delete[] model->weights.attn_k_b;
+ delete[] model->weights.attn_v_w;
+ delete[] model->weights.attn_v_b;
+ delete[] model->weights.attn_o_w;
+ delete[] model->weights.mlp_norm_w;
+ delete[] model->weights.mlp_gate_w;
+ delete[] model->weights.mlp_up_w;
+ delete[] model->weights.mlp_down_w;
+
+ for (auto *h : model->handles) {
+ delete h;
+ }
+ delete model;
+}
+
+// ModelWeights: Get weights pointer for loading pretrained params
+// ModelWeights: Get weights pointer for loading pretrained params
+// NOTE(review): no null check on model — callers must pass a valid handle.
+struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model *model) {
+ return &model->weights;
+}
+
+// ModelInfer: Single interface for single token prediction
+// ModelInfer: Single interface for single token prediction
+// Thin C wrapper; sampling parameters are forwarded to rand_sample unchanged.
+int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model *model, int64_t *token_ids, size_t ntoken, float temperature,
+ size_t topK, float topP, int64_t seed) {
+ return infer_impl(model, token_ids, ntoken, temperature, topK, topP, seed);
+}
+
+}
diff --git a/src/ops/add/nvidia/add_nvidia.cu b/src/ops/add/nvidia/add_nvidia.cu
new file mode 100644
index 00000000..e7307d64
--- /dev/null
+++ b/src/ops/add/nvidia/add_nvidia.cu
@@ -0,0 +1,54 @@
+#include "add_nvidia.hpp"
+
+#include "../../nvidia_utils.cuh"
+#include "../../../core/llaisys_core.hpp"
+
+namespace llaisys::ops::nvidia {
+// Elementwise c[i] = a[i] + b[i], one thread per element.
+// NOTE(review): this patch text is garbled — "template" lost its parameter
+// list (presumably "template <typename T>") and the is_same_v checks lost
+// their type arguments (presumably half/bf16 types); restore from the
+// original commit.
+template
+__global__ void add_kernel(T *c, const T *a, const T *b, size_t numel) {
+ size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;
+ if (idx >= numel) {
+ return;
+ }
+ if constexpr (std::is_same_v || std::is_same_v) {
+ // Reduced-precision types: convert through float for the addition.
+ float av = detail::to_float(a[idx]);
+ float bv = detail::to_float(b[idx]);
+ c[idx] = detail::from_float(av + bv);
+ } else {
+ c[idx] = a[idx] + b[idx];
+ }
+}
+
+// Host-side dispatcher: launches add_kernel on the context stream for the
+// given dtype, then synchronizes. NOTE(review): the kernel launch
+// configurations were stripped from this patch text ("<<>>" — presumably
+// "<<<blocks, threads, 0, stream>>>"), as were the reinterpret_cast and
+// add_kernel template arguments; restore from the original commit.
+// NOTE(review): cudaStreamSynchronize after every op serializes the stream —
+// consider synchronizing once per step if this becomes hot.
+void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel) {
+ int threads = 256;
+ int blocks = static_cast((numel + threads - 1) / threads);
+ auto stream = reinterpret_cast(llaisys::core::context().runtime().stream());
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ add_kernel<<>>(
+ reinterpret_cast(c),
+ reinterpret_cast(a),
+ reinterpret_cast(b),
+ numel);
+ break;
+ case LLAISYS_DTYPE_BF16:
+ add_kernel<<>>(
+ reinterpret_cast(c),
+ reinterpret_cast(a),
+ reinterpret_cast(b),
+ numel);
+ break;
+ case LLAISYS_DTYPE_F16:
+ add_kernel<<>>(
+ reinterpret_cast(c),
+ reinterpret_cast(a),
+ reinterpret_cast(b),
+ numel);
+ break;
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+ detail::checkCuda(cudaGetLastError());
+ detail::checkCuda(cudaStreamSynchronize(stream));
+}
+}
diff --git a/src/ops/add/nvidia/add_nvidia.hpp b/src/ops/add/nvidia/add_nvidia.hpp
new file mode 100644
index 00000000..2e67b497
--- /dev/null
+++ b/src/ops/add/nvidia/add_nvidia.hpp
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "llaisys.h"
+
+// NOTE(review): the include target was stripped from this patch text (bare
+// "#include" — presumably <cstddef> for std::byte); restore it.
+#include
+
+namespace llaisys::ops::nvidia {
+// Elementwise add on NVIDIA devices; defined in add_nvidia.cu.
+void add(std::byte *c, const std::byte *a, const std::byte *b, llaisysDataType_t type, size_t numel);
+}
diff --git a/src/ops/add/op.cpp b/src/ops/add/op.cpp
index a057330d..9b36ba8e 100644
--- a/src/ops/add/op.cpp
+++ b/src/ops/add/op.cpp
@@ -4,6 +4,7 @@
#include "../../utils.hpp"
#include "cpu/add_cpu.hpp"
+#include "nvidia/add_nvidia.hpp"
namespace llaisys::ops {
void add(tensor_t c, tensor_t a, tensor_t b) {
@@ -25,8 +26,7 @@ void add(tensor_t c, tensor_t a, tensor_t b) {
return cpu::add(c->data(), a->data(), b->data(), c->dtype(), c->numel());
#ifdef ENABLE_NVIDIA_API
case LLAISYS_DEVICE_NVIDIA:
- TO_BE_IMPLEMENTED();
- return;
+ return nvidia::add(c->data(), a->data(), b->data(), c->dtype(), c->numel());
#endif
default:
EXCEPTION_UNSUPPORTED_DEVICE;
diff --git a/src/ops/argmax/cpu/argmax_cpu.cpp b/src/ops/argmax/cpu/argmax_cpu.cpp
new file mode 100644
index 00000000..ca0eb53d
--- /dev/null
+++ b/src/ops/argmax/cpu/argmax_cpu.cpp
@@ -0,0 +1,47 @@
+#include "argmax_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include
+
+// Scalar argmax over vals[0..numel): writes the winning index (int64) and
+// value (T). First maximum wins on ties (strict > comparison).
+// NOTE(review): "template" lost its parameter list in this patch text
+// (presumably "template <typename T>"). Also v[0] is read unconditionally —
+// callers must guarantee numel >= 1.
+template
+static void argmax_(std::byte *max_idx, std::byte *max_val, const std::byte *vals, size_t numel) {
+ size_t max_i = 0;
+ const T *v = reinterpret_cast(vals);
+ if constexpr (std::is_same_v || std::is_same_v) {
+ // Reduced-precision types: compare in float.
+ float max_f = llaisys::utils::cast(v[0]);
+ for (size_t i = 1; i < numel; ++i) {
+ float cur = llaisys::utils::cast(v[i]);
+ if (cur > max_f) {
+ max_i = i;
+ max_f = cur;
+ }
+ }
+ *reinterpret_cast(max_val) = llaisys::utils::cast(max_f);
+ } else {
+ T max_v = v[0];
+ for (size_t i = 1; i < numel; ++i) {
+ if (v[i] > max_v) {
+ max_i = i;
+ max_v = v[i];
+ }
+ }
+ *reinterpret_cast(max_val) = max_v;
+ }
+ *reinterpret_cast(max_idx) = static_cast(max_i);
+}
+
+namespace llaisys::ops::cpu {
+// Dtype dispatcher for the CPU argmax kernel above. NOTE(review): the
+// argmax_ template arguments were stripped from this patch text; restore
+// the per-dtype instantiations from the original commit.
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel) {
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ return argmax_(max_idx, max_val, vals, numel);
+ case LLAISYS_DTYPE_BF16:
+ return argmax_(max_idx, max_val, vals, numel);
+ case LLAISYS_DTYPE_F16:
+ return argmax_(max_idx, max_val, vals, numel);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+}
+}
\ No newline at end of file
diff --git a/src/ops/argmax/cpu/argmax_cpu.hpp b/src/ops/argmax/cpu/argmax_cpu.hpp
new file mode 100644
index 00000000..1f3224cb
--- /dev/null
+++ b/src/ops/argmax/cpu/argmax_cpu.hpp
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+
+// NOTE(review): include target stripped in this patch text — presumably
+// <cstddef> for std::byte; restore it.
+#include
+
+namespace llaisys::ops::cpu {
+// Writes the index (int64) and value of the maximum of vals[0..numel).
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel);
+}
\ No newline at end of file
diff --git a/src/ops/argmax/nvidia/argmax_nvidia.cu b/src/ops/argmax/nvidia/argmax_nvidia.cu
new file mode 100644
index 00000000..79a51059
--- /dev/null
+++ b/src/ops/argmax/nvidia/argmax_nvidia.cu
@@ -0,0 +1,66 @@
+#include "argmax_nvidia.hpp"
+
+#include "../../nvidia_utils.cuh"
+#include "../../../core/llaisys_core.hpp"
+
+namespace llaisys::ops::nvidia {
+// Single-thread argmax: only thread (0,0) does a serial O(numel) scan.
+// Correct but leaves the GPU idle — fine for a one-off reduction, but a
+// parallel block reduction would be preferable for large numel.
+// NOTE(review): "template" lost its parameter list and the is_same_v checks
+// lost their type arguments in this patch text; restore from the original
+// commit. Also vals[0] is read unconditionally — numel must be >= 1.
+template
+__global__ void argmax_kernel(int64_t *max_idx, T *max_val, const T *vals, size_t numel) {
+ if (blockIdx.x != 0 || threadIdx.x != 0) {
+ return;
+ }
+ size_t max_i = 0;
+ if constexpr (std::is_same_v || std::is_same_v) {
+ float max_f = detail::to_float(vals[0]);
+ for (size_t i = 1; i < numel; ++i) {
+ float cur = detail::to_float(vals[i]);
+ if (cur > max_f) {
+ max_f = cur;
+ max_i = i;
+ }
+ }
+ *max_val = detail::from_float(max_f);
+ } else {
+ T max_v = vals[0];
+ for (size_t i = 1; i < numel; ++i) {
+ if (vals[i] > max_v) {
+ max_v = vals[i];
+ max_i = i;
+ }
+ }
+ *max_val = max_v;
+ }
+ *max_idx = static_cast(max_i);
+}
+
+// Host-side dispatcher: launches the single-thread argmax kernel on the
+// context stream for the given dtype, then synchronizes.
+// NOTE(review): the kernel template arguments and reinterpret_cast targets
+// were stripped from this patch text; restore from the original commit.
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel) {
+ auto stream = reinterpret_cast(llaisys::core::context().runtime().stream());
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ argmax_kernel<<<1, 1, 0, stream>>>(
+ reinterpret_cast(max_idx),
+ reinterpret_cast(max_val),
+ reinterpret_cast(vals),
+ numel);
+ break;
+ case LLAISYS_DTYPE_BF16:
+ argmax_kernel<<<1, 1, 0, stream>>>(
+ reinterpret_cast(max_idx),
+ reinterpret_cast(max_val),
+ reinterpret_cast(vals),
+ numel);
+ break;
+ case LLAISYS_DTYPE_F16:
+ argmax_kernel<<<1, 1, 0, stream>>>(
+ reinterpret_cast(max_idx),
+ reinterpret_cast(max_val),
+ reinterpret_cast(vals),
+ numel);
+ break;
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+ detail::checkCuda(cudaGetLastError());
+ detail::checkCuda(cudaStreamSynchronize(stream));
+}
+}
diff --git a/src/ops/argmax/nvidia/argmax_nvidia.hpp b/src/ops/argmax/nvidia/argmax_nvidia.hpp
new file mode 100644
index 00000000..054fa353
--- /dev/null
+++ b/src/ops/argmax/nvidia/argmax_nvidia.hpp
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "llaisys.h"
+
+// NOTE(review): include target stripped in this patch text — presumably
+// <cstddef> for std::byte; restore it.
+#include
+
+namespace llaisys::ops::nvidia {
+// Argmax on NVIDIA devices; defined in argmax_nvidia.cu.
+void argmax(std::byte *max_idx, std::byte *max_val, const std::byte *vals, llaisysDataType_t type, size_t numel);
+}
diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp
index 6dc37d42..3482db29 100644
--- a/src/ops/argmax/op.cpp
+++ b/src/ops/argmax/op.cpp
@@ -1,7 +1,36 @@
#include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+
+#include "cpu/argmax_cpu.hpp"
+#include "nvidia/argmax_nvidia.hpp"
+
namespace llaisys::ops {
void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) {
- TO_BE_IMPLEMENTED();
+ // Validate shapes/dtypes/layout, then dispatch by device.
+ CHECK_SAME_DEVICE(max_idx, max_val, vals);
+ ASSERT(vals->ndim() == 1, "Argmax: vals must be 1D.");
+ ASSERT(max_idx->ndim() == 1 && max_idx->shape()[0] == 1, "Argmax: max_idx shape must be (1, ).");
+ ASSERT(max_val->ndim() == 1 && max_val->shape()[0] == 1, "Argmax: max_val shape must be (1, ).");
+ CHECK_SAME_DTYPE(max_val->dtype(), vals->dtype());
+ ASSERT(max_idx->dtype() == LLAISYS_DTYPE_I64, "Argmax: max_idx must be I64.");
+ ASSERT(vals->isContiguous() && max_idx->isContiguous() && max_val->isContiguous(), "Argmax: all tensors must be contiguous.");
+
+ // CPU fast path: skip the device-context switch below.
+ if (vals->deviceType() == LLAISYS_DEVICE_CPU) {
+ return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+ }
+
+ llaisys::core::context().setDevice(vals->deviceType(), vals->deviceId());
+
+ // NOTE(review): the LLAISYS_DEVICE_CPU case below is unreachable because
+ // of the early return above — dead code; consider removing one of the two
+ // CPU paths.
+ switch (vals->deviceType()) {
+ case LLAISYS_DEVICE_CPU:
+ return cpu::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return nvidia::argmax(max_idx->data(), max_val->data(), vals->data(), vals->dtype(), vals->numel());
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
}
} // namespace llaisys::ops
diff --git a/src/ops/embedding/cpu/embedding_cpu.cpp b/src/ops/embedding/cpu/embedding_cpu.cpp
new file mode 100644
index 00000000..e87fb5e2
--- /dev/null
+++ b/src/ops/embedding/cpu/embedding_cpu.cpp
@@ -0,0 +1,33 @@
+#include "embedding_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+// Gather rows: out[i] = weight[idx[i]] for i in [0, index_len), each row
+// row_len elements wide. NOTE(review): "template" lost its parameter list in
+// this patch text (presumably "template <typename T>"), and the
+// reinterpret_cast targets their type arguments; restore them. Also the row
+// index is used unchecked — a negative or out-of-vocab token id is an
+// out-of-bounds read; confirm callers validate token ids.
+template
+static void embedding_(std::byte *out, const std::byte *index, const std::byte *weight, size_t index_len, size_t row_len) {
+ const auto idx = reinterpret_cast(index);
+ const auto w = reinterpret_cast(weight);
+ auto o = reinterpret_cast(out);
+ for (size_t i = 0; i < index_len; ++i) {
+ const int64_t row = idx[i];
+ const T *src = w + row * row_len;
+ T *dst = o + i * row_len;
+ for (size_t j = 0; j < row_len; ++j) {
+ dst[j] = src[j];
+ }
+ }
+}
+
+namespace llaisys::ops::cpu {
+// Dtype dispatcher for the CPU embedding gather above. NOTE(review): the
+// embedding_ template arguments were stripped from this patch text; restore
+// the per-dtype instantiations from the original commit.
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, size_t index_len, size_t row_len) {
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ return embedding_(out, index, weight, index_len, row_len);
+ case LLAISYS_DTYPE_BF16:
+ return embedding_(out, index, weight, index_len, row_len);
+ case LLAISYS_DTYPE_F16:
+ return embedding_(out, index, weight, index_len, row_len);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+}
+}
\ No newline at end of file
diff --git a/src/ops/embedding/cpu/embedding_cpu.hpp b/src/ops/embedding/cpu/embedding_cpu.hpp
new file mode 100644
index 00000000..9aec5020
--- /dev/null
+++ b/src/ops/embedding/cpu/embedding_cpu.hpp
@@ -0,0 +1,8 @@
+#pragma once
+#include "llaisys.h"
+
+#include
+
+namespace llaisys::ops::cpu {
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, size_t index_len, size_t row_len);
+}
\ No newline at end of file
diff --git a/src/ops/embedding/nvidia/embedding_nvidia.cu b/src/ops/embedding/nvidia/embedding_nvidia.cu
new file mode 100644
index 00000000..c8832472
--- /dev/null
+++ b/src/ops/embedding/nvidia/embedding_nvidia.cu
@@ -0,0 +1,56 @@
+#include "embedding_nvidia.hpp"
+
+#include "../../nvidia_utils.cuh"
+#include "../../../core/llaisys_core.hpp"
+
+namespace llaisys::ops::nvidia {
+template <typename T>
+__global__ void embedding_kernel(T *out, const int64_t *index, const T *weight, size_t index_len, size_t row_len) { // one thread per output element
+ size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+ size_t total = index_len * row_len;
+ if (idx >= total) {
+ return;
+ }
+ size_t i = idx / row_len;
+ size_t j = idx - i * row_len;
+ int64_t row = index[i];
+ out[idx] = weight[static_cast<size_t>(row) * row_len + j];
+}
+
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, size_t index_len, size_t row_len) {
+ size_t total = index_len * row_len;
+ int threads = 256;
+ int blocks = static_cast<int>((total + threads - 1) / threads);
+ auto stream = reinterpret_cast<cudaStream_t>(llaisys::core::context().runtime().stream());
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ embedding_kernel<float><<<blocks, threads, 0, stream>>>(
+ reinterpret_cast<float *>(out),
+ reinterpret_cast<const int64_t *>(index),
+ reinterpret_cast<const float *>(weight),
+ index_len,
+ row_len);
+ break;
+ case LLAISYS_DTYPE_BF16:
+ embedding_kernel<llaisys::bf16_t><<<blocks, threads, 0, stream>>>(
+ reinterpret_cast<llaisys::bf16_t *>(out),
+ reinterpret_cast<const int64_t *>(index),
+ reinterpret_cast<const llaisys::bf16_t *>(weight),
+ index_len,
+ row_len);
+ break;
+ case LLAISYS_DTYPE_F16:
+ embedding_kernel<llaisys::fp16_t><<<blocks, threads, 0, stream>>>(
+ reinterpret_cast<llaisys::fp16_t *>(out),
+ reinterpret_cast<const int64_t *>(index),
+ reinterpret_cast<const llaisys::fp16_t *>(weight),
+ index_len,
+ row_len);
+ break;
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+ detail::checkCuda(cudaGetLastError());
+ detail::checkCuda(cudaStreamSynchronize(stream));
+}
+}
diff --git a/src/ops/embedding/nvidia/embedding_nvidia.hpp b/src/ops/embedding/nvidia/embedding_nvidia.hpp
new file mode 100644
index 00000000..29ce3a28
--- /dev/null
+++ b/src/ops/embedding/nvidia/embedding_nvidia.hpp
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+void embedding(std::byte *out, const std::byte *index, const std::byte *weight, llaisysDataType_t type, size_t index_len, size_t row_len);
+}
diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp
index 84b9a5d0..01abaac7 100644
--- a/src/ops/embedding/op.cpp
+++ b/src/ops/embedding/op.cpp
@@ -1,7 +1,39 @@
#include "op.hpp"
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+#include "cpu/embedding_cpu.hpp"
+#include "nvidia/embedding_nvidia.hpp"
namespace llaisys::ops {
void embedding(tensor_t out, tensor_t index, tensor_t weight) {
- TO_BE_IMPLEMENTED();
+ CHECK_SAME_DEVICE(out, index, weight);
+ ASSERT(index->ndim() == 1, "Embedding: index must be 1D.");
+ ASSERT(weight->ndim() == 2, "Embedding: weight must be 2D.");
+ ASSERT(out->ndim() == 2, "Embedding: out must be 2D.");
+ ASSERT(index->dtype() == LLAISYS_DTYPE_I64, "Embedding: index must be I64.");
+ CHECK_SAME_DTYPE(out->dtype(), weight->dtype());
+ ASSERT(out->shape()[0] == index->shape()[0], "Embedding: out.shape[0] must equal index.shape[0].");
+ ASSERT(out->shape()[1] == weight->shape()[1], "Embedding: out.shape[1] must equal weight.shape[1].");
+ ASSERT(out->isContiguous() && index->isContiguous() && weight->isContiguous(), "Embedding: all tensors must be contiguous.");
+
+ size_t index_len = index->shape()[0];
+ size_t row_len = weight->shape()[1];
+
+ if (weight->deviceType() == LLAISYS_DEVICE_CPU) { // CPU fast path: returns before setDevice is called
+ return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index_len, row_len);
+ }
+
+ llaisys::core::context().setDevice(weight->deviceType(), weight->deviceId()); // bind context to the weight tensor's device before dispatch
+
+ switch (weight->deviceType()) {
+ case LLAISYS_DEVICE_CPU: // NOTE(review): unreachable — already handled by the early return above
+ return cpu::embedding(out->data(), index->data(), weight->data(), out->dtype(), index_len, row_len);
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return nvidia::embedding(out->data(), index->data(), weight->data(), out->dtype(), index_len, row_len);
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
}
} // namespace llaisys::ops
diff --git a/src/ops/linear/cpu/linear_cpu.cpp b/src/ops/linear/cpu/linear_cpu.cpp
new file mode 100644
index 00000000..520d8b0e
--- /dev/null
+++ b/src/ops/linear/cpu/linear_cpu.cpp
@@ -0,0 +1,71 @@
+#include "linear_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+template <typename T>
+static void linear_(T *out, const T *in, const T *W, const T *bias,
+ size_t batch, size_t in_dim, size_t out_dim)
+{
+ for(size_t i = 0; i < batch; ++i)
+ {
+ const T *in_ = in + i*in_dim;
+ T *out_ = out + i*out_dim;
+ for(size_t j = 0; j < out_dim; ++j)
+ {
+ const T *weight_ = W + j*in_dim;
+ if constexpr (std::is_same_v<T, llaisys::bf16_t>
+ || std::is_same_v<T, llaisys::fp16_t>)
+ {
+ float acc = 0.0f; // accumulate in fp32 for the half-precision types
+ for(size_t k = 0; k < in_dim; ++k)
+ {
+ acc += llaisys::utils::cast<float>(in_[k]) *
+ llaisys::utils::cast<float>(weight_[k]);
+ }
+ if(bias)
+ acc += llaisys::utils::cast<float>(bias[j]);
+ out_[j] = llaisys::utils::cast<T>(acc);
+ }
+ else
+ {
+ T acc = static_cast<T>(0);
+ for(size_t k = 0; k < in_dim; ++k)
+ {
+ acc += in_[k] * weight_[k];
+ }
+ if(bias)
+ acc += bias[j];
+ out_[j] = acc;
+ }
+ }
+ }
+
+}
+namespace llaisys::ops::cpu {
+void linear(std::byte *out, const std::byte *in, const std::byte *W,
+ const std::byte *bias, llaisysDataType_t type, size_t batch,
+ size_t in_dim, size_t out_dim) {
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ return linear_<float>(reinterpret_cast<float *>(out),
+ reinterpret_cast<const float *>(in),
+ reinterpret_cast<const float *>(W),
+ reinterpret_cast<const float *>(bias),
+ batch, in_dim, out_dim);
+ case LLAISYS_DTYPE_BF16:
+ return linear_<llaisys::bf16_t>(reinterpret_cast<llaisys::bf16_t *>(out),
+ reinterpret_cast<const llaisys::bf16_t *>(in),
+ reinterpret_cast<const llaisys::bf16_t *>(W),
+ reinterpret_cast<const llaisys::bf16_t *>(bias),
+ batch, in_dim, out_dim);
+ case LLAISYS_DTYPE_F16:
+ return linear_<llaisys::fp16_t>(reinterpret_cast<llaisys::fp16_t *>(out),
+ reinterpret_cast<const llaisys::fp16_t *>(in),
+ reinterpret_cast<const llaisys::fp16_t *>(W),
+ reinterpret_cast<const llaisys::fp16_t *>(bias),
+ batch, in_dim, out_dim);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+}
+}
diff --git a/src/ops/linear/cpu/linear_cpu.hpp b/src/ops/linear/cpu/linear_cpu.hpp
new file mode 100644
index 00000000..a6f4e279
--- /dev/null
+++ b/src/ops/linear/cpu/linear_cpu.hpp
@@ -0,0 +1,10 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void linear(std::byte *out, const std::byte *in, const std::byte *W,
+ const std::byte *bias, llaisysDataType_t type, size_t batch,
+ size_t in_dim, size_t out_dim);
+}
diff --git a/src/ops/linear/nvidia/linear_nvidia.cu b/src/ops/linear/nvidia/linear_nvidia.cu
new file mode 100644
index 00000000..fc7066f3
--- /dev/null
+++ b/src/ops/linear/nvidia/linear_nvidia.cu
@@ -0,0 +1,75 @@
+#include "linear_nvidia.hpp"
+
+#include "../../nvidia_utils.cuh"
+#include "../../../core/llaisys_core.hpp"
+
+namespace llaisys::ops::nvidia {
+template <typename T>
+__global__ void linear_kernel(T *out, const T *in, const T *W, const T *bias, size_t batch, size_t in_dim, size_t out_dim) { // one thread per (row, out-col) pair
+ size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+ size_t total = batch * out_dim;
+ if (idx >= total) {
+ return;
+ }
+ size_t i = idx / out_dim;
+ size_t j = idx - i * out_dim;
+ const T *in_row = in + i * in_dim;
+ const T *w_row = W + j * in_dim;
+ if constexpr (std::is_same_v<T, llaisys::fp16_t> || std::is_same_v<T, llaisys::bf16_t>) {
+ float acc = 0.0f; // fp32 accumulation for half-precision types
+ for (size_t k = 0; k < in_dim; ++k) {
+ acc += detail::to_float(in_row[k]) * detail::to_float(w_row[k]);
+ }
+ if (bias) {
+ acc += detail::to_float(bias[j]);
+ }
+ out[idx] = detail::from_float<T>(acc);
+ } else {
+ T acc = static_cast<T>(0);
+ for (size_t k = 0; k < in_dim; ++k) {
+ acc += in_row[k] * w_row[k];
+ }
+ if (bias) {
+ acc += bias[j];
+ }
+ out[idx] = acc;
+ }
+}
+
+void linear(std::byte *out, const std::byte *in, const std::byte *W, const std::byte *bias, llaisysDataType_t type, size_t batch, size_t in_dim, size_t out_dim) {
+ size_t total = batch * out_dim;
+ int threads = 256;
+ int blocks = static_cast<int>((total + threads - 1) / threads);
+ auto stream = reinterpret_cast<cudaStream_t>(llaisys::core::context().runtime().stream());
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ linear_kernel<float><<<blocks, threads, 0, stream>>>(
+ reinterpret_cast<float *>(out),
+ reinterpret_cast<const float *>(in),
+ reinterpret_cast<const float *>(W),
+ reinterpret_cast<const float *>(bias),
+ batch, in_dim, out_dim);
+ break;
+ case LLAISYS_DTYPE_BF16:
+ linear_kernel<llaisys::bf16_t><<<blocks, threads, 0, stream>>>(
+ reinterpret_cast<llaisys::bf16_t *>(out),
+ reinterpret_cast<const llaisys::bf16_t *>(in),
+ reinterpret_cast<const llaisys::bf16_t *>(W),
+ reinterpret_cast<const llaisys::bf16_t *>(bias),
+ batch, in_dim, out_dim);
+ break;
+ case LLAISYS_DTYPE_F16:
+ linear_kernel<llaisys::fp16_t><<<blocks, threads, 0, stream>>>(
+ reinterpret_cast<llaisys::fp16_t *>(out),
+ reinterpret_cast<const llaisys::fp16_t *>(in),
+ reinterpret_cast<const llaisys::fp16_t *>(W),
+ reinterpret_cast<const llaisys::fp16_t *>(bias),
+ batch, in_dim, out_dim);
+ break;
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+ detail::checkCuda(cudaGetLastError());
+ detail::checkCuda(cudaStreamSynchronize(stream));
+}
+}
diff --git a/src/ops/linear/nvidia/linear_nvidia.hpp b/src/ops/linear/nvidia/linear_nvidia.hpp
new file mode 100644
index 00000000..f31ea3fb
--- /dev/null
+++ b/src/ops/linear/nvidia/linear_nvidia.hpp
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::nvidia {
+void linear(std::byte *out, const std::byte *in, const std::byte *W, const std::byte *bias, llaisysDataType_t type, size_t batch, size_t in_dim, size_t out_dim);
+}
diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp
index 97d1f865..0804ca50 100644
--- a/src/ops/linear/op.cpp
+++ b/src/ops/linear/op.cpp
@@ -1,7 +1,50 @@
#include "op.hpp"
-
+#include "../../core/llaisys_core.hpp"
+#include "../../utils.hpp"
+#include "cpu/linear_cpu.hpp"
+#include "nvidia/linear_nvidia.hpp"
namespace llaisys::ops {
void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) {
- TO_BE_IMPLEMENTED();
+ ASSERT(out->isContiguous() && in->isContiguous()
+ && weight->isContiguous(), "out, in, weight must be contiguous tensors");
+
+ ASSERT(out->ndim() == 2 && in->ndim() == 2 &&
+ weight->ndim() == 2, "out, in, weight must be 2D tensors");
+ ASSERT(in->shape()[1] == weight->shape()[1],"in.shape[1] must be equal to weight.shape[1]");
+ ASSERT(out->shape()[1] == weight->shape()[0], "out.shape[1] must be equal to weight.shape[0]");
+ CHECK_SAME_DTYPE(out->dtype(), in->dtype(), weight->dtype());
+ if(bias)
+ {
+ CHECK_SAME_DEVICE(out, bias, weight);
+ ASSERT(bias->ndim() == 1, "bias must be 1D tensor");
+ ASSERT(bias->isContiguous(), "bias must be contiguous tensor");
+ ASSERT(bias->shape()[0] == out->shape()[1], "bias.shape[0] must be equal to out.shape[1]");
+ CHECK_SAME_DTYPE(out->dtype(), bias->dtype());
+ }
+
+ size_t batch = out->shape()[0];
+ size_t in_dim = in->shape()[1];
+ size_t out_dim = out->shape()[1];
+
+ const std::byte *bias_data = bias ? bias->data():nullptr;
+
+ if (weight->deviceType() == LLAISYS_DEVICE_CPU) {
+ return cpu::linear(out->data(), in->data(),
+ weight->data(), bias_data, out->dtype(), batch, in_dim, out_dim); // out->dtype(): bias may be nullptr, so bias->dtype() would be UB
+ }
+
+ llaisys::core::context().setDevice(weight->deviceType(), weight->deviceId());
+
+ switch (weight->deviceType()) {
+ case LLAISYS_DEVICE_CPU:
+ return cpu::linear(out->data(), in->data(),
+ weight->data(), bias_data, out->dtype(), batch, in_dim, out_dim); // out->dtype(): bias may be nullptr
+#ifdef ENABLE_NVIDIA_API
+ case LLAISYS_DEVICE_NVIDIA:
+ return nvidia::linear(out->data(), in->data(), weight->data(), bias_data, out->dtype(), batch, in_dim, out_dim);
+#endif
+ default:
+ EXCEPTION_UNSUPPORTED_DEVICE;
+ }
}
} // namespace llaisys::ops
diff --git a/src/ops/nvidia_utils.cuh b/src/ops/nvidia_utils.cuh
new file mode 100644
index 00000000..04ed31f2
--- /dev/null
+++ b/src/ops/nvidia_utils.cuh
@@ -0,0 +1,123 @@
+#pragma once
+
+#include "llaisys.h"
+#include "../utils.hpp"
+#include <cuda_runtime.h>
+#include <cstdint>
+#include <stdexcept>
+#include <string>
+#include <type_traits>
+
+namespace llaisys::ops::nvidia::detail {
+
+inline void checkCuda(cudaError_t err) {
+ if (err != cudaSuccess) {
+ throw std::runtime_error(std::string("CUDA error: ") + cudaGetErrorString(err));
+ }
+}
+
+__device__ inline float f16_to_f32(llaisys::fp16_t val) { // IEEE binary16 -> binary32 bit manipulation
+ uint16_t h = val._v;
+ uint32_t sign = (h & 0x8000) << 16;
+ int32_t exponent = (h >> 10) & 0x1F;
+ uint32_t mantissa = h & 0x3FF;
+ uint32_t f32;
+ if (exponent == 31) {
+ f32 = mantissa != 0 ? (sign | 0x7F800000 | (mantissa << 13)) : (sign | 0x7F800000);
+ } else if (exponent == 0) {
+ if (mantissa == 0) {
+ f32 = sign;
+ } else {
+ exponent = -14;
+ while ((mantissa & 0x400) == 0) {
+ mantissa <<= 1;
+ exponent--;
+ }
+ mantissa &= 0x3FF;
+ f32 = sign | ((exponent + 127) << 23) | (mantissa << 13);
+ }
+ } else {
+ f32 = sign | ((exponent + 127 - 15) << 23) | (mantissa << 13);
+ }
+ union {
+ uint32_t u;
+ float f;
+ } tmp;
+ tmp.u = f32;
+ return tmp.f;
+}
+
+__device__ inline llaisys::fp16_t f32_to_f16(float val) { // binary32 -> binary16, truncating rounding
+ union {
+ uint32_t u;
+ float f;
+ } tmp;
+ tmp.f = val;
+ uint32_t f32 = tmp.u;
+ uint16_t sign = (f32 >> 16) & 0x8000;
+ int32_t exponent = ((f32 >> 23) & 0xFF) - 127;
+ uint32_t mantissa = f32 & 0x7FFFFF;
+ if (exponent >= 16) {
+ if (exponent == 128 && mantissa != 0) {
+ return llaisys::fp16_t{static_cast<uint16_t>(sign | 0x7E00)};
+ }
+ return llaisys::fp16_t{static_cast<uint16_t>(sign | 0x7C00)};
+ } else if (exponent >= -14) {
+ return llaisys::fp16_t{static_cast<uint16_t>(sign | ((exponent + 15) << 10) | (mantissa >> 13))};
+ } else if (exponent >= -24) {
+ mantissa |= 0x800000;
+ mantissa >>= (-14 - exponent);
+ return llaisys::fp16_t{static_cast<uint16_t>(sign | (mantissa >> 13))};
+ }
+ return llaisys::fp16_t{static_cast<uint16_t>(sign)};
+}
+
+__device__ inline float bf16_to_f32(llaisys::bf16_t val) {
+ uint32_t bits32 = static_cast<uint32_t>(val._v) << 16;
+ union {
+ uint32_t u;
+ float f;
+ } tmp;
+ tmp.u = bits32;
+ return tmp.f;
+}
+
+__device__ inline llaisys::bf16_t f32_to_bf16(float val) { // round-to-nearest-even truncation to the top 16 bits
+ union {
+ uint32_t u;
+ float f;
+ } tmp;
+ tmp.f = val;
+ uint32_t bits32 = tmp.u;
+ const uint32_t rounding_bias = 0x00007FFF + ((bits32 >> 16) & 1);
+ uint16_t bf16_bits = static_cast<uint16_t>((bits32 + rounding_bias) >> 16);
+ return llaisys::bf16_t{bf16_bits};
+}
+
+template <typename T>
+__device__ inline float to_float(T v) {
+ if constexpr (std::is_same_v<T, float>) {
+ return v;
+ } else if constexpr (std::is_same_v<T, llaisys::fp16_t>) {
+ return f16_to_f32(v);
+ } else if constexpr (std::is_same_v<T, llaisys::bf16_t>) {
+ return bf16_to_f32(v);
+ } else {
+ return static_cast<float>(v);
+ }
+}
+
+template <typename T>
+__device__ inline T from_float(float v) {
+ if constexpr (std::is_same_v<T, float>) {
+ return v;
+ } else if constexpr (std::is_same_v<T, llaisys::fp16_t>) {
+ return f32_to_f16(v);
+ } else if constexpr (std::is_same_v<T, llaisys::bf16_t>) {
+ return f32_to_bf16(v);
+ } else {
+ return static_cast<T>(v);
+ }
+}
+
+}
diff --git a/src/ops/rand_sample/cpu/rand_sample_cpu.cpp b/src/ops/rand_sample/cpu/rand_sample_cpu.cpp
new file mode 100644
index 00000000..9ca5e0e5
--- /dev/null
+++ b/src/ops/rand_sample/cpu/rand_sample_cpu.cpp
@@ -0,0 +1,133 @@
+#include "rand_sample_cpu.hpp"
+
+#include "../../../utils.hpp"
+
+#include <algorithm>
+#include <cmath>
+#include <random>
+#include <vector>
+
+template <typename T>
+static void rand_sample_(std::byte *sample_idx, std::byte *sample_val, const std::byte *vals, const float temperature,
+ const size_t topK, const float topP, size_t numel, const int64_t batch_size, const int64_t seed) {
+ const T *v_all = reinterpret_cast<const T *>(vals);
+ auto *out_idx = reinterpret_cast<int64_t *>(sample_idx);
+ auto *out_val = reinterpret_cast<T *>(sample_val);
+ float temp = temperature;
+ if (temp <= 1e-6f) {
+ temp = 1e-6f;
+ }
+
+ std::vector<float> scores(numel);
+ std::vector<std::pair<float, size_t>> sorted_scores(numel);
+ std::vector<std::pair<float, size_t>> candidates;
+ candidates.reserve(numel);
+
+ std::mt19937_64 rng(static_cast<uint64_t>(seed));
+ std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+
+ for (int64_t b = 0; b < batch_size; ++b) {
+ const T *v = v_all + b * numel;
+ size_t max_i = 0;
+ float max_v = llaisys::utils::cast<float>(v[0]);
+ for (size_t i = 1; i < numel; ++i) {
+ float cur = llaisys::utils::cast<float>(v[i]);
+ if (cur > max_v) {
+ max_v = cur;
+ max_i = i;
+ }
+ }
+
+ for (size_t i = 0; i < numel; ++i) {
+ float cur = llaisys::utils::cast<float>(v[i]);
+ scores[i] = std::exp((cur - max_v) / temp);
+ }
+ float sum = 0.0f;
+ for (size_t i = 0; i < numel; ++i) {
+ sum += scores[i];
+ }
+ if (sum <= 0.0f) {
+ for (size_t i = 0; i < numel; ++i) {
+ scores[i] = 0.0f;
+ }
+ scores[max_i] = 1.0f;
+ sum = 1.0f;
+ } else {
+ for (size_t i = 0; i < numel; ++i) {
+ scores[i] /= sum;
+ }
+ }
+
+ for (size_t i = 0; i < numel; ++i) {
+ sorted_scores[i] = {scores[i], i};
+ }
+ std::sort(sorted_scores.begin(), sorted_scores.end(),
+ [](const auto &a, const auto &b) { return a.first > b.first; });
+
+ size_t k = numel;
+ if (topK > 0 && topK < numel) {
+ k = topK;
+ }
+
+ candidates.clear();
+ if (topP > 0.0f && topP < 1.0f) {
+ float cum_score = 0.0f;
+ for (size_t i = 0; i < k; ++i) {
+ candidates.push_back(sorted_scores[i]);
+ cum_score += sorted_scores[i].first;
+ if (cum_score >= topP) {
+ break;
+ }
+ }
+ } else {
+ for (size_t i = 0; i < k; ++i) {
+ candidates.push_back(sorted_scores[i]);
+ }
+ }
+ if (candidates.empty()) {
+ candidates.push_back(sorted_scores[0]);
+ }
+
+ float cand_sum = 0.0f;
+ for (const auto &item : candidates) {
+ cand_sum += item.first;
+ }
+
+ size_t chosen = candidates[0].second;
+ if (cand_sum > 0.0f) {
+ float r = dist(rng);
+ for (const auto &item : candidates) {
+ r -= item.first / cand_sum;
+ if (r <= 0.0f) {
+ chosen = item.second;
+ break;
+ }
+ }
+ if (r > 0.0f) {
+ chosen = candidates.back().second;
+ }
+ }
+
+ out_idx[b] = static_cast<int64_t>(chosen);
+ out_val[b] = v[chosen];
+ }
+}
+
+namespace llaisys::ops::cpu {
+void rand_sample(std::byte *sample_idx, std::byte *sample_val, const std::byte *vals, llaisysDataType_t type, size_t numel,
+ const int64_t batch_size, const float temperature, const size_t topK, const float topP, const int64_t seed) {
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ return rand_sample_<float>(sample_idx, sample_val, vals,
+ temperature, topK, topP, numel, batch_size, seed);
+ case LLAISYS_DTYPE_BF16:
+ return rand_sample_<llaisys::bf16_t>(sample_idx, sample_val, vals,
+ temperature, topK, topP, numel, batch_size, seed);
+ case LLAISYS_DTYPE_F16:
+ return rand_sample_<llaisys::fp16_t>(sample_idx, sample_val, vals,
+ temperature, topK, topP, numel, batch_size, seed);
+ default:
+ EXCEPTION_UNSUPPORTED_DATATYPE(type);
+ }
+}
+}
diff --git a/src/ops/rand_sample/cpu/rand_sample_cpu.hpp b/src/ops/rand_sample/cpu/rand_sample_cpu.hpp
new file mode 100644
index 00000000..59ad95e7
--- /dev/null
+++ b/src/ops/rand_sample/cpu/rand_sample_cpu.hpp
@@ -0,0 +1,9 @@
+#pragma once
+#include "llaisys.h"
+
+#include <cstddef>
+
+namespace llaisys::ops::cpu {
+void rand_sample(std::byte *sample_idx, std::byte *sample_val, const std::byte *vals, llaisysDataType_t type, size_t numel,
+ const int64_t batch_size, const float temperature, const size_t topK, const float topP, const int64_t seed);
+}
diff --git a/src/ops/rand_sample/nvidia/rand_sample_nvidia.cu b/src/ops/rand_sample/nvidia/rand_sample_nvidia.cu
new file mode 100644
index 00000000..cdb00f7b
--- /dev/null
+++ b/src/ops/rand_sample/nvidia/rand_sample_nvidia.cu
@@ -0,0 +1,172 @@
+#include "rand_sample_nvidia.hpp"
+
+#include "../../../core/llaisys_core.hpp"
+#include "../../nvidia_utils.cuh"
+
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <cuda_runtime.h>
+
+namespace llaisys::ops::nvidia {
+__device__ inline uint64_t lcg_next(uint64_t &state) { // 64-bit LCG step (Knuth constants)
+ state = state * 6364136223846793005ULL + 1ULL;
+ return state;
+}
+
+__device__ inline float rng_uniform(uint64_t &state) { // uniform float in [0, 1) from the top 24 bits
+ uint64_t x = lcg_next(state);
+ uint32_t mant = static_cast<uint32_t>(x >> 40);
+ return static_cast<float>(mant) * (1.0f / 16777216.0f);
+}
+
+template <typename T>
+__global__ void rand_sample_kernel(int64_t *out_idx, T *out_val, const T *vals, size_t numel, int64_t batch_size,
+ float temperature, size_t topK, float topP, int64_t seed, float *probs, int *idx) {
+ int64_t b = static_cast<int64_t>(blockIdx.x);
+ if (b >= batch_size) {
+ return;
+ }
+ if (threadIdx.x != 0) { // single thread per block does the whole row (serial reference path)
+ return;
+ }
+ const T *v = vals + static_cast<size_t>(b) * numel;
+ float *p = probs + static_cast<size_t>(b) * numel;
+ int *id = idx + static_cast<size_t>(b) * numel;
+ float temp = temperature <= 1e-6f ? 1e-6f : temperature;
+ size_t max_i = 0;
+ float max_v = detail::to_float(v[0]);
+ for (size_t i = 1; i < numel; ++i) {
+ float cur = detail::to_float(v[i]);
+ if (cur > max_v) {
+ max_v = cur;
+ max_i = i;
+ }
+ }
+ float sum = 0.0f;
+ for (size_t i = 0; i < numel; ++i) {
+ float cur = detail::to_float(v[i]);
+ float val = expf((cur - max_v) / temp);
+ p[i] = val;
+ sum += val;
+ id[i] = static_cast<int>(i);
+ }
+ if (sum <= 0.0f) {
+ for (size_t i = 0; i < numel; ++i) {
+ p[i] = 0.0f;
+ }
+ p[max_i] = 1.0f;
+ sum = 1.0f;
+ } else {
+ float inv = 1.0f / sum;
+ for (size_t i = 0; i < numel; ++i) {
+ p[i] *= inv;
+ }
+ }
+ size_t k = numel;
+ if (topK > 0 && topK < numel) {
+ k = topK;
+ }
+ if ((topK > 0 && topK < numel) || (topP > 0.0f && topP < 1.0f)) {
+ for (size_t i = 0; i < k; ++i) { // partial selection sort: only the top-k prefix is ordered
+ size_t max_pos = i;
+ float max_val = p[i];
+ for (size_t j = i + 1; j < numel; ++j) {
+ float vj = p[j];
+ if (vj > max_val) {
+ max_val = vj;
+ max_pos = j;
+ }
+ }
+ if (max_pos != i) {
+ float tmp = p[i];
+ p[i] = p[max_pos];
+ p[max_pos] = tmp;
+ int tmp_i = id[i];
+ id[i] = id[max_pos];
+ id[max_pos] = tmp_i;
+ }
+ }
+ }
+ size_t cand = k;
+ if (topP > 0.0f && topP < 1.0f) {
+ float cum = 0.0f;
+ cand = 0;
+ for (size_t i = 0; i < k; ++i) {
+ cum += p[i];
+ cand = i + 1;
+ if (cum >= topP) {
+ break;
+ }
+ }
+ if (cand == 0) {
+ cand = 1;
+ }
+ }
+ float cand_sum = 0.0f;
+ for (size_t i = 0; i < cand; ++i) {
+ cand_sum += p[i];
+ }
+ size_t chosen = static_cast<size_t>(id[0]);
+ if (cand_sum > 0.0f) {
+ uint64_t state = static_cast<uint64_t>(seed) ^ (static_cast<uint64_t>(b + 1) * 0x9e3779b97f4a7c15ULL);
+ float r = rng_uniform(state) * cand_sum;
+ for (size_t i = 0; i < cand; ++i) {
+ r -= p[i];
+ if (r <= 0.0f) {
+ chosen = static_cast<size_t>(id[i]);
+ break;
+ }
+ if (i == cand - 1) {
+ chosen = static_cast<size_t>(id[i]);
+ }
+ }
+ }
+ out_idx[b] = static_cast<int64_t>(chosen);
+ out_val[b] = v[chosen];
+}
+
+void rand_sample(std::byte *sample_idx, std::byte *sample_val, const std::byte *vals, llaisysDataType_t type, size_t numel,
+ int64_t batch_size, float temperature, size_t topK, float topP, int64_t seed) {
+ auto &runtime = llaisys::core::context().runtime();
+ if (batch_size <= 0 || numel == 0) {
+ return;
+ }
+ auto d_probs = static_cast(runtime.api()->malloc_device(sizeof(float) * static_cast(batch_size) * numel));
+ auto d_idx = static_cast(runtime.api()->malloc_device(sizeof(int) * static_cast(batch_size) * numel));
+ dim3 grid(static_cast(batch_size), 1, 1);
+ int threads = 1;
+ auto stream = reinterpret_cast(runtime.stream());
+ switch (type) {
+ case LLAISYS_DTYPE_F32:
+ rand_sample_kernel<<>>(
+ reinterpret_cast(sample_idx),
+ reinterpret_cast(sample_val),
+ reinterpret_cast(vals),
+ numel, batch_size, temperature, topK, topP, seed, d_probs, d_idx);
+ break;
+ case LLAISYS_DTYPE_BF16:
+ rand_sample_kernel<<