Skip to content

Commit

Permalink
tinyblas dynamic dispatching
Browse files Browse the repository at this point in the history
  • Loading branch information
Djip007 committed Dec 14, 2024
1 parent f48c35d commit 30ae0d2
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 57 deletions.
4 changes: 2 additions & 2 deletions examples/server/tests/unit/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_completion_stream_vs_non_stream():
@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
global server
- server.n_slots = 1
+ server.n_slots = n_slots
server.start()
last_res = None
for _ in range(4):
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_different_result_different_seed(n_slots: int):
@pytest.mark.parametrize("temperature", [0.0, 1.0])
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
global server
- server.n_batch = 1
+ server.n_batch = n_batch
server.start()
last_res = None
for _ in range(4):
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -7420,14 +7420,14 @@ static void ggml_compute_forward_mul_mat(
if (src1_cont) {
for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
- if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+ if (!llamafile_sgemm(params,
+ ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(src0->type),
(const char *)src1->data + i12*nb12 + i13*nb13,
nb11/ggml_type_size(src1->type),
(char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type),
- ith, nth,
src0->type,
src1->type,
dst->type))
Expand Down Expand Up @@ -7472,14 +7472,14 @@ UseGgmlGemm1:;

for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++)
- if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+ if (!llamafile_sgemm(params,
+ ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(src0->type),
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
row_size/ggml_type_size(vec_dot_type),
(char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type),
- ith, nth,
src0->type,
vec_dot_type,
dst->type))
Expand Down
Loading

0 comments on commit 30ae0d2

Please sign in to comment.