Skip to content

Commit 38d7996

Browse files
committed
fix(mtmd): prevent batch splitting by capping n_batch to n_ubatch ikawrakow#988
2 parents 37bdd92 + 2d8a1db commit 38d7996

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

examples/mtmd/mtmd-helper.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,13 @@ int32_t mtmd_helper_decode_image_chunk(
185185
int n_mmproj_embd = llama_model_n_embd_inp(model);
186186
int n_pos_per_embd = mtmd_decode_use_mrope(ctx) ? 4 : 1;
187187

188+
// ensure we don't exceed n_ubatch, otherwise llama_decode will try to split the batch
189+
// which will break M-RoPE positional embeddings
190+
int32_t n_ubatch = llama_n_ubatch(lctx);
191+
if (n_batch > n_ubatch) {
192+
n_batch = n_ubatch;
193+
}
194+
188195
int32_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk);
189196
int32_t i_batch = 0;
190197
int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;

0 commit comments

Comments
 (0)