A small, GPU-only JAX inference engine. Currently, only single-GPU inference with Qwen3 models is supported.
Example usage:

```python
from aphrodite_jax import LLM, SamplingParams

llm = LLM("Qwen/Qwen3-0.6B", max_model_len=4096)
outputs = llm.generate(
    ["Hello from Aphrodite-JAX"],
    SamplingParams(temperature=0.6, max_tokens=32),
)
print(outputs[0]["text"])
```

Performance benchmark:

```bash
python -m aphrodite_jax.bench_perf -m Qwen/Qwen3-0.6B
```

There is currently no compile cache or AOT compilation, so each new input shape triggers a separate compilation.