diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d407033..21a2888 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -30,6 +30,7 @@ jobs: python -m venv venv source venv/bin/activate pip install -e .[dev,examples] + pip install --force-reinstall 'triton==3.2.0' # Add the project root to the PYTHONPATH for examples PYTHONPATH=$PYTHONPATH:$(pwd) pytest tests --cov=llamppl --cov-report=json @@ -41,3 +42,35 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} files: ./coverage.json slug: genlm/llamppl + + + test_mlx: + runs-on: macos-14 + + steps: + - name: Check out repository + uses: actions/checkout@v4 + with: + fetch-depth: 1 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11.5' + cache: 'pip' + + - name: Run tests with MLX extras + run: | + python -m venv venv + source venv/bin/activate + pip install -e .[mlx,dev,examples] + PYTHONPATH=$PYTHONPATH:$(pwd) pytest tests --cov=llamppl --cov-report=json + + - name: Upload MLX coverage to Codecov + uses: codecov/codecov-action@v5 + with: + fail_ci_if_error: false + disable_search: true + token: ${{ secrets.CODECOV_TOKEN }} + files: ./coverage.json + slug: genlm/llamppl diff --git a/README.md b/README.md index 94637a3..4ec4c82 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,12 @@ To get started on your own machine, you can install this library from PyPI: pip install llamppl ``` +For faster inference on Apple Silicon devices, you can install it with the MLX backend: + +```bash +pip install "llamppl[mlx]" +``` + ### Local installation For local development, clone this repository and run `pip install -e ".[dev,examples]"` to install `llamppl` and its development dependencies. 
diff --git a/benchmark/benchmark_backend.py b/benchmark/benchmark_backend.py index c841c7e..567668c 100644 --- a/benchmark/benchmark_backend.py +++ b/benchmark/benchmark_backend.py @@ -11,7 +11,7 @@ from examples.haiku import run_example as run_haiku from examples.hard_constraints import run_example as run_hard_constraints -from hfppl.llms import CachedCausalLM +from llamppl.llms import CachedCausalLM, MLX_AVAILABLE backends = [ "hf", @@ -21,6 +21,12 @@ not torch.cuda.is_available(), reason="vLLM backend requires CUDA" ), ), + pytest.param( + "mlx", + marks=pytest.mark.skipif( + not MLX_AVAILABLE, reason="MLX backend requires MLX-LM" + ), + ), ] diff --git a/llamppl/llms.py b/llamppl/llms.py index 1ad14b3..1de65d3 100644 --- a/llamppl/llms.py +++ b/llamppl/llms.py @@ -8,6 +8,7 @@ from genlm.backend.llm import AsyncTransformer from genlm.backend.llm import AsyncVirtualLM from genlm.backend.llm import MockAsyncLM +from genlm.backend.llm import AsyncMlxLM VLLM_AVAILABLE = True try: @@ -15,6 +16,12 @@ except ImportError: VLLM_AVAILABLE = False +MLX_AVAILABLE = True +try: + import mlx_lm +except ImportError: + MLX_AVAILABLE = False + warnings.filterwarnings("once", category=DeprecationWarning) warnings.filterwarnings("once", category=RuntimeWarning) @@ -201,6 +208,7 @@ def from_pretrained(cls, model_id, backend=None, **kwargs): - 'vllm' to instantiate an `AsyncVirtualLM`; ideal for GPU usage - 'hf' for an `AsyncTransformer`; ideal for CPU usage - 'mock' for a `MockAsyncLM`; ideal for testing. + - 'mlx' for an `AsyncMlxLM`; ideal for usage on devices with Apple silicon. Defaults to 'vllm' if CUDA is available, otherwise 'hf'. **kwargs: Additional keyword arguments passed to the `AsyncLM` constructor. See [`AsyncLM` documentation](https://probcomp.github.io/genlm-backend/reference/genlm_backend/llm/__init__/). 
@@ -223,9 +231,11 @@ def from_pretrained(cls, model_id, backend=None, **kwargs): model_cls = AsyncTransformer elif backend == "mock": model_cls = MockAsyncLM + elif backend == "mlx": + model_cls = AsyncMlxLM else: raise ValueError( - f"Unknown backend: {backend}. Must be one of ['vllm', 'hf', 'mock']" + f"Unknown backend: {backend}. Must be one of ['vllm', 'hf', 'mock', 'mlx']" ) # Handle legacy auth_token parameter. The ability to pass in the auth_token should @@ -280,9 +290,11 @@ def __init__(self, model): self.backend = "hf" elif isinstance(model, MockAsyncLM): self.backend = "mock" + elif isinstance(model, AsyncMlxLM): + self.backend = "mlx" else: raise ValueError( - f"Unknown model type: {type(model)}. Must be one of [AsyncVirtualLM, AsyncTransformer, MockAsyncLM]" + f"Unknown model type: {type(model)}. Must be one of [AsyncVirtualLM, AsyncTransformer, MockAsyncLM, AsyncMlxLM]" ) self.model = model @@ -340,6 +352,8 @@ def clear_kv_cache(self): """Clear any key and value vectors from the cache.""" if self.backend == "hf": self.model.clear_kv_cache() + elif self.backend == "mlx": + self.model.clear_cache() elif self.backend == "vllm": warnings.warn( "clear_kv_cache() is only supported for the HuggingFace backend. The KV cache for the vLLM backend is handled internally by vLLM. No operation performed.", @@ -355,7 +369,7 @@ def clear_kv_cache(self): def reset_async_queries(self): """Clear any pending language model queries from the queue.""" - if self.backend == "hf": + if self.backend in ["hf", "mlx"]: self.model.reset_async_queries() elif self.backend == "vllm": warnings.warn( @@ -376,7 +390,7 @@ def cache_kv(self, prompt_tokens): Args: prompt_tokens (list[int]): token ids for the prompt to cache. 
""" - if self.backend == "hf": + if self.backend in ["hf", "mlx"]: self.model.cache_kv(prompt_tokens) elif self.backend == "vllm": warnings.warn( diff --git a/pyproject.toml b/pyproject.toml index 4c33e1e..3f9d8db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ [project.optional-dependencies] vllm = ["vllm>=0.6.6"] +mlx = ["genlm-backend[mlx]>=0.1.7"] dev = [ "pytest", "pytest-benchmark", diff --git a/tests/test_examples.py b/tests/test_examples.py index e4bdff5..e341ad5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -5,25 +5,32 @@ from examples.haiku import run_example as run_haiku from examples.hard_constraints import run_example as run_hard_constraints -from llamppl.llms import CachedCausalLM - -backends = [ - "mock", - "hf", - pytest.param( - "vllm", - marks=pytest.mark.skipif( - not torch.cuda.is_available(), reason="vLLM backend requires CUDA" +from llamppl.llms import CachedCausalLM, MLX_AVAILABLE + +if MLX_AVAILABLE: + backends = ["mock", "mlx"] +else: + backends = [ + "mock", + "hf", + pytest.param( + "vllm", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="vLLM backend requires CUDA" + ), ), - ), -] + ] @pytest.fixture def LLM(backend): # Set lower gpu_memory_utilization in vllm so that we can fit both models on the GPU kwargs = ( - {"engine_opts": {"gpu_memory_utilization": 0.45}} if backend == "vllm" else {} + {"engine_opts": {"gpu_memory_utilization": 0.45}} + if backend == "vllm" + else {"cache_size": 10} + if backend == "mlx" + else {} ) return CachedCausalLM.from_pretrained("gpt2", backend=backend, **kwargs) diff --git a/tests/test_lmcontext.py b/tests/test_lmcontext.py index e310c37..92d9187 100644 --- a/tests/test_lmcontext.py +++ b/tests/test_lmcontext.py @@ -5,23 +5,27 @@ import torch from llamppl.distributions.lmcontext import LMContext -from llamppl.llms import CachedCausalLM - -backends = [ - "mock", - "hf", - pytest.param( - "vllm", - marks=pytest.mark.skipif( - not 
torch.cuda.is_available(), reason="vLLM backend requires CUDA" +from llamppl.llms import CachedCausalLM, MLX_AVAILABLE + +if MLX_AVAILABLE: + backends = ["mock", "mlx"] +else: + backends = [ + "mock", + "hf", + pytest.param( + "vllm", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="vLLM backend requires CUDA" + ), ), - ), -] + ] @pytest.fixture def lm(backend): - return CachedCausalLM.from_pretrained("gpt2", backend=backend) + kwargs = {"cache_size": 10} if backend == "mlx" else {} + return CachedCausalLM.from_pretrained("gpt2", backend=backend, **kwargs) @pytest.mark.parametrize("backend", backends) @@ -33,7 +37,7 @@ def test_init(lm): np.testing.assert_allclose( lmcontext.next_token_logprobs, logprobs, - rtol=1e-5, + rtol=5e-4, err_msg="Sync context __init__", ) @@ -44,7 +48,7 @@ async def async_context(): np.testing.assert_allclose( lmcontext.next_token_logprobs, logprobs, - rtol=1e-5, + rtol=5e-4, err_msg="Async context __init__", ) @@ -55,6 +59,6 @@ async def async_context_create(): np.testing.assert_allclose( lmcontext.next_token_logprobs, logprobs, - rtol=1e-5, + rtol=5e-4, err_msg="Async context create", )