Batched KV caching for fast parallel inference on Apple Silicon devices, via MLX.
This repo borrows heavily from mlx_lm. Adding batched generation there as a non-breaking PR will be explored.
Requires mlx and mlx_lm to be installed.
```python
from mlx_parallm.utils import load, batch_generate

model, tokenizer = load("google/gemma-1.1-2b-it")
prompts = ["prompt_0", ..., "prompt_k"]
responses = batch_generate(model, tokenizer, prompts=prompts, max_tokens=100, verbose=True, format_prompts=True, temp=0.0)
```

Models tested:
- meta-llama/Meta-Llama-3-8B-Instruct
- microsoft/Phi-3-mini-4k-instruct
- google/gemma-1.1-2b-it
- mlx-community/Meta-Llama-3-8B-Instruct-4bit
- mlx-community/Phi-3-mini-4k-instruct-4bit
- mlx-community/gemma-1.1-2b-it-4bit
Both quantized and float16 models are supported. float16 models generally run faster when sufficient RAM is available (1300+ tok/s throughput for gemma-2b on an M3 Max with 128 GB).
Additional models can be added by copying architecture files from `mlx_lm/models` and replacing any references to `KVCache` with `BatchedKVCache`.
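As a rough sketch of that porting step, the hypothetical helper below copies an architecture file out of the installed mlx_lm package and renames the cache class. The `port_architecture` name, the destination path, and the assumption that the name swap needs no further edits are illustrative only; the same change can just as easily be made by hand.

```python
# Hypothetical porting helper: copy an mlx_lm architecture file and swap the
# cache class, as described above. Paths and package layout are assumptions.
from pathlib import Path

import mlx_lm.models as mlx_lm_models


def port_architecture(arch: str, dest_dir: str = "mlx_parallm/models") -> Path:
    src = Path(list(mlx_lm_models.__path__)[0]) / f"{arch}.py"
    code = src.read_text()
    # BatchedKVCache mirrors the KVCache interface, so renaming the references
    # is usually the only change required.
    code = code.replace("KVCache", "BatchedKVCache")
    dest = Path(dest_dir) / f"{arch}.py"
    dest.parent.mkdir(parents=True, exist_ok=True)
    dest.write_text(code)
    return dest


# e.g. port_architecture("qwen2")  # any architecture file present in mlx_lm/models
```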
Supported:
- `batch_generate` method (tested with `len(prompts) > 500`)
- Auto-padding
- Auto-formatting with prompt templates (`format_prompts=True`)
- `temp = 0`, `temp > 0`, `top_p` sampling (see the sketch below)
- single-stream `generate` method
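A minimal sketch of the sampling options and the single-stream path, assuming `generate` is exposed alongside `batch_generate` in `mlx_parallm.utils` and that `top_p` is passed as a keyword argument (mirroring mlx_lm); the prompt strings are placeholders:

```python
from mlx_parallm.utils import load, batch_generate, generate

model, tokenizer = load("mlx-community/gemma-1.1-2b-it-4bit")

# Batched generation with temperature + nucleus sampling (temp > 0, top_p < 1).
prompts = ["Summarize KV caching in one sentence."] * 8
responses = batch_generate(
    model,
    tokenizer,
    prompts=prompts,
    max_tokens=64,
    format_prompts=True,  # apply the tokenizer's chat template to each prompt
    temp=0.7,
    top_p=0.9,
)

# Greedy decoding (temp = 0) for a single prompt via the single-stream method.
response = generate(model, tokenizer, prompt="Summarize KV caching in one sentence.", max_tokens=64, temp=0.0)
```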
Not (yet) supported:
- Repetition penalties
- Streaming outputs for `batch_generate`
- Dynamic batching for async requests