Merged

36 commits
a791b9e
opencl: support ne3 in get_rows (llama/15866)
lhez Sep 30, 2025
7879b37
ggml webgpu: support for rope,div,sub,glu,scale,cont operators (llama…
reeselevine Sep 30, 2025
0b2e1e4
opencl: support pad_ext (llama/15888)
lhez Sep 30, 2025
32e75a6
vulkan: make ggml_vk_default_dispatcher support older vulkan headers …
netrunnereve Oct 1, 2025
ebae117
HIP: Disable ROCWMMA fattn on CDNA when compiled against ROCWMMA 2.0.…
IMbackK Oct 1, 2025
efb1344
musa: update compile flags (llama/16265)
yeahdongcn Oct 2, 2025
519f8c4
model : Apertus model implementation (llama/15852)
pwilkin Oct 2, 2025
3e930a4
ggml webgpu: add support for soft_max, optimize rms_norm (llama/16357)
reeselevine Oct 2, 2025
ba94efe
vulkan: in flash attention, bounds check against nem1 (don't rely on …
jeffbolznv Oct 3, 2025
e0c619e
vulkan: Fix FA coopmat1 invalid array indexing (llama/16365)
jeffbolznv Oct 3, 2025
f3fe64f
vulkan: Replace uses of maxMemoryAllocationSize and VK_WHOLE_SIZE (ll…
jeffbolznv Oct 3, 2025
187a56f
ggml : fix graph reallocation with multiple chunks (llama/16396)
Acly Oct 3, 2025
5bd10f5
metal : fix loop bound in ggml_mem_ranges (llama/16412)
ggerganov Oct 3, 2025
4a41781
sync : llama.cpp
ggerganov Oct 11, 2025
2767954
vulkan : incremental shader builds (llama/16341)
Acly Oct 11, 2025
e02c134
sync : llama.cpp
ggerganov Oct 11, 2025
4f0730d
rpc : add support for multiple devices (llama/16276)
rgerganov Oct 4, 2025
0df2131
rpc : check src buffer when copying tensor (llama/16421)
rgerganov Oct 4, 2025
082230a
vulkan: use a more appropriate amount of threads when generating shad…
netrunnereve Oct 4, 2025
9bb0a5f
ggml webgpu: actually add softmax, fix rms_norm offset (llama/16400)
reeselevine Oct 5, 2025
a33af4d
ggml-cpu : fix leftover handling in ggml_vec_scale_f32 for SVE (llama…
danbev Oct 6, 2025
53017b7
ggml : fix unaligned access in AMX code (llama/16315)
ggerganov Oct 6, 2025
3d6002f
metal : various optimizations + refactoring (llama/16446)
ggerganov Oct 7, 2025
cf7fbdc
tests : add -INF blocks to the KQ mask in the FA tests (llama/16380)
ggerganov Oct 7, 2025
d777d4a
metal : add support for non-padded FA KV (llama/16148)
ggerganov Oct 7, 2025
7255dc5
ggml webgpu: profiling, CI updates, reworking of command submission (…
reeselevine Oct 7, 2025
5d3854a
metal : mark FA blocks (llama/16372)
ggerganov Oct 8, 2025
b15f0b2
Disable CUDA host buffers on integrated GPUs (llama/16308)
ai-fonsi Oct 8, 2025
77dd6c9
refactor soft_max, add soft_max_back (llama/16472)
NeoZhangJianyu Oct 9, 2025
91dbc93
kleidiai: kernel interface refactoring (llama/16460)
chaxu01 Oct 9, 2025
626d187
CANN: Improve ACL graph matching (llama/16166)
noemotiovon Oct 9, 2025
c11d293
cpu : optimize the ggml NORM operation (llama/15953)
duduta Oct 9, 2025
d13a5fe
cmake : Dont define XOPENSOURCE on AIX (llama/16481)
mehendarkarprajwal Oct 10, 2025
5c85a3c
cuda : avoid initializing unused devices (llama/16510)
slaren Oct 11, 2025
a50be4e
metal : fix mul-mm condition + fix mul-mv permuted kernels (llama/16494)
ggerganov Oct 11, 2025
9a238bf
sync : llama.cpp
ggerganov Oct 11, 2025
CMakeLists.txt (4 changes: 3 additions & 1 deletion)
@@ -209,7 +209,6 @@ option(GGML_HIP "ggml: use HIP"
 option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
 option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
 option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
-option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
 option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
 option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
 option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
@@ -223,6 +222,9 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
 option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
 option(GGML_WEBGPU "ggml: use WebGPU" OFF)
 option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
+option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF)
+option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF)
+
 option(GGML_ZDNN "ggml: use zDNN" OFF)
 option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
 option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
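The two new profiling switches mirror the existing GGML_WEBGPU_DEBUG option; presumably they are toggled at configure time in the usual CMake way, e.g. -DGGML_WEBGPU=ON -DGGML_WEBGPU_GPU_PROFILE=ON.
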
include/ggml-backend.h (2 changes: 2 additions & 0 deletions)
@@ -215,6 +215,8 @@ extern "C" {
     // Backend registry
     //
 
+    GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
+
     GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
 
     // Backend (reg) enumeration
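This is the same ggml_backend_register declaration that the src/ggml-backend-impl.h hunk below removes: registration moves from the internal registry API into the public header, presumably so out-of-tree code (such as the reworked RPC server flow) can register a backend reg explicitly.
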
include/ggml-rpc.h (17 changes: 8 additions & 9 deletions)
@@ -7,26 +7,25 @@
 extern "C" {
 #endif
 
-#define RPC_PROTO_MAJOR_VERSION 2
+#define RPC_PROTO_MAJOR_VERSION 3
 #define RPC_PROTO_MINOR_VERSION 0
 #define RPC_PROTO_PATCH_VERSION 0
-#define GGML_RPC_MAX_SERVERS 16
 
 // backend API
-GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
+GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device);
 GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
 
-GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device);
 
-GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
+GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
 
-GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
-                                                    const char * cache_dir,
-                                                    size_t free_mem, size_t total_mem);
+GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
+                                                    size_t n_threads, size_t n_devices,
+                                                    ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
 
 GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
 
-GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
 
 #ifdef __cplusplus
 }
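Taken together, ggml-rpc.h now addresses each remote device as an (endpoint, device) pair and exposes one backend reg per server instead of one device per endpoint. A minimal client-side sketch of the new flow, using only the declarations above plus the generic registry helpers from ggml-backend.h (treating ggml_backend_reg_dev_count as the intended way to enumerate the per-server reg is an assumption):

// sketch: enumerate the devices of one RPC server and create a backend per device
#include "ggml-backend.h"
#include "ggml-rpc.h"
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const char * endpoint = "127.0.0.1:50052";  // hypothetical server address

    // one reg per server (protocol v3)
    ggml_backend_reg_t reg = ggml_backend_rpc_add_server(endpoint);

    for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
        size_t free_mem = 0, total_mem = 0;
        ggml_backend_rpc_get_device_memory(endpoint, (uint32_t) i, &free_mem, &total_mem);
        printf("rpc device %zu: %zu free / %zu total bytes\n", i, free_mem, total_mem);

        // per-device backend via the new uint32_t device argument
        ggml_backend_t backend = ggml_backend_rpc_init(endpoint, (uint32_t) i);
        // ... allocate buffers and run graphs against this device ...
        ggml_backend_free(backend);
    }
    return 0;
}
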
include/ggml.h (22 changes: 22 additions & 0 deletions)
@@ -237,6 +237,8 @@
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1
 
+// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726
+#define GGML_ROPE_TYPE_NORMAL 0
 #define GGML_ROPE_TYPE_NEOX 2
 #define GGML_ROPE_TYPE_MROPE 8
 #define GGML_ROPE_TYPE_VISION 24

@@ -574,6 +576,7 @@ extern "C" {
         GGML_UNARY_OP_HARDSIGMOID,
         GGML_UNARY_OP_EXP,
         GGML_UNARY_OP_GELU_ERF,
+        GGML_UNARY_OP_XIELU,
 
         GGML_UNARY_OP_COUNT,
     };

@@ -1148,6 +1151,18 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    // xIELU activation function
+    // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
+    // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
+    // that constrain the positive and negative source alpha values respectively
+    GGML_API struct ggml_tensor * ggml_xielu(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            float alpha_n,
+            float alpha_p,
+            float beta,
+            float eps);
+
     // gated linear unit ops
     // A: n columns, r rows,
     // result is n / 2 columns, r rows,

@@ -1615,6 +1630,13 @@ extern "C" {
             float scale,
             float max_bias);
 
+    GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * mask,
+            float scale,
+            float max_bias);
+
     GGML_API void ggml_soft_max_add_sinks(
             struct ggml_tensor * a,
             struct ggml_tensor * sinks);
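The new xIELU op is specified only by the comment above, so for reference here is a scalar sketch that follows that formula literally (an illustration, not the ggml implementation: the real op is created with ggml_xielu and the backend kernels may use numerically safer softplus/sigmoid variants):

#include <math.h>

// constraining functions named in the comment: c_a = softplus, c_b(a, b) = softplus(a) + b
static float softplus_f(float v) { return log1pf(expf(v)); }
static float sigmoid_f(float v)  { return 1.0f / (1.0f + expf(-v)); }

// x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
static float xielu_scalar(float x, float alpha_n, float alpha_p, float beta, float eps) {
    const float c_a = softplus_f(alpha_n);
    const float c_b = softplus_f(alpha_p) + beta;
    return x * (c_a + c_b * sigmoid_f(beta * x)) + (x > 0.0f ? eps : 0.0f);
}

ggml_soft_max_ext_inplace, also added here, follows the usual ggml _inplace naming convention: same mask/scale/max_bias semantics as ggml_soft_max_ext, writing the result into its input tensor.
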
scripts/sync-llama.last (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-a014310374a16f9204f2bcc1b458fc1eda67e469
+a3cb04744fb5c591985f53b749fef5407d07a145
src/CMakeLists.txt (3 changes: 3 additions & 0 deletions)
@@ -145,6 +145,9 @@ endif()
 # which was introduced in POSIX.1-2008, forcing us to go higher
 if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
     add_compile_definitions(_XOPEN_SOURCE=700)
+elseif (CMAKE_SYSTEM_NAME MATCHES "AIX")
+    # Don't define _XOPEN_SOURCE. We need _ALL_SOURCE, which is the default,
+    # in order to define _SC_PHYS_PAGES.
 else()
     add_compile_definitions(_XOPEN_SOURCE=600)
 endif()
src/ggml-alloc.c (30 changes: 16 additions & 14 deletions)
@@ -392,12 +392,8 @@ static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) {
     free(alloc);
 }
 
-static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) {
-    size_t max_size = 0;
-    for (int i = 0; i < alloc->n_chunks; i++) {
-        max_size += alloc->chunks[i]->max_size;
-    }
-    return max_size;
+static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc, int chunk) {
+    return chunk < alloc->n_chunks ? alloc->chunks[chunk]->max_size : 0;
 }
 
 

@@ -417,10 +413,8 @@ static void ggml_vbuffer_free(struct vbuffer * buf) {
     free(buf);
 }
 
-static int ggml_vbuffer_n_chunks(struct vbuffer * buf) {
-    int n = 0;
-    while (n < GGML_VBUFFER_MAX_CHUNKS && buf->chunks[n]) n++;
-    return n;
+static size_t ggml_vbuffer_chunk_size(struct vbuffer * buf, int chunk) {
+    return buf->chunks[chunk] ? ggml_backend_buffer_get_size(buf->chunks[chunk]) : 0;
 }
 
 static size_t ggml_vbuffer_size(struct vbuffer * buf) {

@@ -885,12 +879,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
             }
         }
 
-        size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
-        size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
 
-        // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
-        if (new_size > cur_size || galloc->buffers[i] == NULL) {
+        bool realloc = galloc->buffers[i] == NULL;
+        size_t new_size = 0;
+        for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) {
+            size_t cur_chunk_size = galloc->buffers[i] ? ggml_vbuffer_chunk_size(galloc->buffers[i], c) : 0;
+            size_t new_chunk_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i], c);
+            new_size += new_chunk_size;
+            if (new_chunk_size > cur_chunk_size) {
+                realloc = true;
+            }
+        }
+        if (realloc) {
 #ifndef NDEBUG
+            size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
             GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
 
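The reallocation check is the substance of this fix. Previously the code compared only the summed per-chunk max sizes against the existing buffer total, so growth in one chunk could be masked by shrinkage in another. Hypothetical two-chunk illustration: existing chunk sizes {64, 32} MiB and new requirements {48, 48} MiB both total 96 MiB, so the old new_size > cur_size test skips reallocation even though chunk 1 is 16 MiB short; the per-chunk comparison above triggers a reallocation in exactly this case.
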
src/ggml-backend-impl.h (3 changes: 0 additions & 3 deletions)
@@ -209,9 +209,6 @@ extern "C" {
         void * context;
     };
 
-    // Internal backend registry API
-    GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
-
     // Add backend dynamic loading support to the backend
 
     // Initialize the backend
src/ggml-cann/common.h (9 changes: 8 additions & 1 deletion)
@@ -341,11 +341,18 @@ class cann_task_queue {
 
 #ifdef USE_ACL_GRAPH
 struct ggml_graph_node_properties {
+    // dst tensor
     void * node_address;
-    ggml_op node_op;
     int64_t ne[GGML_MAX_DIMS];
     size_t nb[GGML_MAX_DIMS];
+
+    // src tensor
     void * src_address[GGML_MAX_SRC];
+    int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
+    size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
+
+    // op
+    ggml_op node_op;
     int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 };
 
src/ggml-cann/ggml-cann.cpp (48 changes: 37 additions & 11 deletions)
@@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties(
     std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
 
     for (int src = 0; src < GGML_MAX_SRC; ++src) {
-        prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
+        if (node->src[src]) {
+            prop.src_address[src] = node->src[src]->data;
+            std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
+            std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
+        } else {
+            prop.src_address[src] = nullptr;
+            std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
+            std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
+        }
     }
 
     memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);

@@ -2206,14 +2214,18 @@
  * @param graph_node_properties The stored properties of a CANN graph node.
  * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
  */
-static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+static bool ggml_graph_node_has_matching_properties(
+    ggml_tensor * node,
+    ggml_graph_node_properties * graph_node_properties) {
     if (node->data != graph_node_properties->node_address &&
-            node->op != GGML_OP_VIEW) {
+        node->op != GGML_OP_VIEW) {
         return false;
     }
+
     if (node->op != graph_node_properties->node_op) {
         return false;
     }
+
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
         if (node->ne[i] != graph_node_properties->ne[i]) {
             return false;

@@ -2222,17 +2234,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
             return false;
         }
     }
+
     for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (node->src[i] &&
-            node->src[i]->data != graph_node_properties->src_address[i] &&
-            node->op != GGML_OP_VIEW
-        ) {
-            return false;
+        if (node->src[i]) {
+            if (node->src[i]->data != graph_node_properties->src_address[i] &&
+                node->op != GGML_OP_VIEW) {
+                return false;
+            }
+
+            for (int d = 0; d < GGML_MAX_DIMS; d++) {
+                if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
+                    return false;
+                }
+                if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
+                    return false;
+                }
+            }
+        } else {
+            if (graph_node_properties->src_address[i] != nullptr) {
+                return false;
+            }
         }
     }
-    if (node->op == GGML_OP_SCALE &&
-        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
-        return false;
+
+    if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
+        return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
     }
     return true;
 }
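Net effect: the ACL graph cache key for a node now covers each source's full shape and strides and rejects a null/non-null source mismatch, rather than checking source data pointers alone. The op_params comparison, previously applied only to GGML_OP_SCALE, now also gates GGML_OP_UNARY and GGML_OP_GLU, whose operator variant is stored in op_params.
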
src/ggml-cpu/amx/amx.cpp (1 change: 1 addition & 0 deletions)
@@ -149,6 +149,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
         if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous
             is_contiguous_2d(op->src[1]) && // src1 must be contiguous
             op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
+            op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
             op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x
             (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
             // src1 must be host buffer
src/ggml-cpu/ggml-cpu.c (1 change: 1 addition & 0 deletions)
@@ -2187,6 +2187,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_SILU:
+        case GGML_UNARY_OP_XIELU:
             {
                 n_tasks = n_threads;
             } break;