Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
run: |
python -m venv venv
source venv/bin/activate
pip install -e .[dev,examples]
pip install -e .[vllm,dev,examples]
# Add the project root to the PYTHONPATH for examples
PYTHONPATH=$PYTHONPATH:$(pwd) pytest tests --cov=llamppl --cov-report=json

Expand All @@ -41,3 +41,35 @@ jobs:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.json
slug: genlm/llamppl


test_mlx:
runs-on: macos-14

steps:
- name: Check out repository
uses: actions/checkout@v4
with:
fetch-depth: 1

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11.5'
cache: 'pip'

- name: Run tests with MLX extras
run: |
python -m venv venv
source venv/bin/activate
pip install -e .[mlx,dev,examples]
PYTHONPATH=$PYTHONPATH:$(pwd) pytest tests --cov=llamppl --cov-report=json

- name: Upload MLX coverage to Codecov
uses: codecov/codecov-action@v5
with:
fail_ci_if_error: false
disable_search: true
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.json
slug: genlm/llamppl
8 changes: 7 additions & 1 deletion benchmark/benchmark_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from examples.haiku import run_example as run_haiku
from examples.hard_constraints import run_example as run_hard_constraints
from hfppl.llms import CachedCausalLM
from llamppl.llms import CachedCausalLM, MLX_AVAILABLE

backends = [
"hf",
Expand All @@ -21,6 +21,12 @@
not torch.cuda.is_available(), reason="vLLM backend requires CUDA"
),
),
pytest.param(
"mlx",
marks=pytest.mark.skipif(
not MLX_AVAILABLE, reason="MLX backend requires MLX-LM"
),
),
]


Expand Down
20 changes: 16 additions & 4 deletions llamppl/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,20 @@
from genlm.backend.llm import AsyncTransformer
from genlm.backend.llm import AsyncVirtualLM
from genlm.backend.llm import MockAsyncLM
from genlm.backend.llm import AsyncMlxLM

VLLM_AVAILABLE = True
try:
import vllm
except ImportError:
VLLM_AVAILABLE = False

MLX_AVAILABLE = True
try:
import mlx_lm
except ImportError:
MLX_AVAILABLE = False

warnings.filterwarnings("once", category=DeprecationWarning)
warnings.filterwarnings("once", category=RuntimeWarning)

Expand Down Expand Up @@ -201,6 +208,7 @@ def from_pretrained(cls, model_id, backend=None, **kwargs):
- 'vllm' to instantiate an `AsyncVirtualLM`; ideal for GPU usage
- 'hf' for an `AsyncTransformer`; ideal for CPU usage
- 'mock' for a `MockAsyncLM`; ideal for testing.
- 'mlx' for an `AsyncMlxLM`; ideal for usage on devices with Apple silicon.
Defaults to 'vllm' if CUDA is available, otherwise 'hf'.
**kwargs: Additional keyword arguments passed to the `AsyncLM` constructor.
See [`AsyncLM` documentation](https://probcomp.github.io/genlm-backend/reference/genlm_backend/llm/__init__/).
Expand All @@ -223,9 +231,11 @@ def from_pretrained(cls, model_id, backend=None, **kwargs):
model_cls = AsyncTransformer
elif backend == "mock":
model_cls = MockAsyncLM
elif backend == "mlx":
model_cls = AsyncMlxLM
else:
raise ValueError(
f"Unknown backend: {backend}. Must be one of ['vllm', 'hf', 'mock']"
f"Unknown backend: {backend}. Must be one of ['vllm', 'hf', 'mock', 'mlx']"
)

# Handle legacy auth_token parameter. The ability to pass in the auth_token should
Expand Down Expand Up @@ -280,9 +290,11 @@ def __init__(self, model):
self.backend = "hf"
elif isinstance(model, MockAsyncLM):
self.backend = "mock"
elif isinstance(model, AsyncMlxLM):
self.backend = "mlx"
else:
raise ValueError(
f"Unknown model type: {type(model)}. Must be one of [AsyncVirtualLM, AsyncTransformer, MockAsyncLM]"
f"Unknown model type: {type(model)}. Must be one of [AsyncVirtualLM, AsyncTransformer, MockAsyncLM, AsyncMlxLM]"
)

self.model = model
Expand Down Expand Up @@ -355,7 +367,7 @@ def clear_kv_cache(self):

def reset_async_queries(self):
"""Clear any pending language model queries from the queue."""
if self.backend == "hf":
if self.backend in ["hf", "mlx"]:
self.model.reset_async_queries()
elif self.backend == "vllm":
warnings.warn(
Expand All @@ -376,7 +388,7 @@ def cache_kv(self, prompt_tokens):
Args:
prompt_tokens (list[int]): token ids for the prompt to cache.
"""
if self.backend == "hf":
if self.backend in ["hf", "mlx"]:
self.model.cache_kv(prompt_tokens)
elif self.backend == "vllm":
warnings.warn(
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ dependencies = [
]

[project.optional-dependencies]
vllm = ["vllm>=0.6.6"]
vllm = [
"vllm>=0.6.6",
"triton==3.2.0",
]
mlx = ["genlm-backend[mlx]>=0.1.7"]
dev = [
"pytest",
"pytest-benchmark",
Expand Down
31 changes: 19 additions & 12 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,32 @@

from examples.haiku import run_example as run_haiku
from examples.hard_constraints import run_example as run_hard_constraints
from llamppl.llms import CachedCausalLM

backends = [
"mock",
"hf",
pytest.param(
"vllm",
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="vLLM backend requires CUDA"
from llamppl.llms import CachedCausalLM, MLX_AVAILABLE

if MLX_AVAILABLE:
backends = ["mock", "mlx"]
else:
backends = [
"mock",
"hf",
pytest.param(
"vllm",
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="vLLM backend requires CUDA"
),
),
),
]
]


@pytest.fixture
def LLM(backend):
# Set lower gpu_memory_utilization in vllm so that we can fit both models on the GPU
kwargs = (
{"engine_opts": {"gpu_memory_utilization": 0.45}} if backend == "vllm" else {}
{"engine_opts": {"gpu_memory_utilization": 0.45}}
if backend == "vllm"
else {"cache_size": 10}
if backend == "mlx"
else {}
)
return CachedCausalLM.from_pretrained("gpt2", backend=backend, **kwargs)

Expand Down
34 changes: 19 additions & 15 deletions tests/test_lmcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,27 @@
import torch

from llamppl.distributions.lmcontext import LMContext
from llamppl.llms import CachedCausalLM

backends = [
"mock",
"hf",
pytest.param(
"vllm",
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="vLLM backend requires CUDA"
from llamppl.llms import CachedCausalLM, MLX_AVAILABLE

if MLX_AVAILABLE:
backends = ["mock", "mlx"]
else:
backends = [
"mock",
"hf",
pytest.param(
"vllm",
marks=pytest.mark.skipif(
not torch.cuda.is_available(), reason="vLLM backend requires CUDA"
),
),
),
]
]


@pytest.fixture
def lm(backend):
return CachedCausalLM.from_pretrained("gpt2", backend=backend)
kwargs = {"cache_size": 10} if backend == "mlx" else {}
return CachedCausalLM.from_pretrained("gpt2", backend=backend, **kwargs)


@pytest.mark.parametrize("backend", backends)
Expand All @@ -33,7 +37,7 @@ def test_init(lm):
np.testing.assert_allclose(
lmcontext.next_token_logprobs,
logprobs,
rtol=1e-5,
rtol=5e-4,
err_msg="Sync context __init__",
)

Expand All @@ -44,7 +48,7 @@ async def async_context():
np.testing.assert_allclose(
lmcontext.next_token_logprobs,
logprobs,
rtol=1e-5,
rtol=5e-4,
err_msg="Async context __init__",
)

Expand All @@ -55,6 +59,6 @@ async def async_context_create():
np.testing.assert_allclose(
lmcontext.next_token_logprobs,
logprobs,
rtol=1e-5,
rtol=5e-4,
err_msg="Async context create",
)