[Model] Jamba support #4115

Merged
117 commits, merged on Jul 2, 2024
Changes from 86 commits
Commits
117 commits
337f67a
Merged in jurassic-2.5 (pull request #1)
ErezSC42 Mar 18, 2024
0330e14
Merged in jamba-3 (pull request #4)
Apr 16, 2024
07cc899
Jamba mamba (#3)
mzusman Apr 2, 2024
6d336f6
Cuda graph (#5)
mzusman Apr 8, 2024
00bce1f
dtype (#6)
mzusman Apr 8, 2024
39c27b7
N support (#8)
mzusman Apr 14, 2024
7c75868
Tensor parallelism (#7)
mzusman Apr 14, 2024
30e6dcd
After merge fixes
Apr 14, 2024
5c0efdc
Clean up
Apr 14, 2024
19f11f3
Add release mamba cache to executor_base
Apr 15, 2024
1fb817a
Add jamba modifications
Apr 15, 2024
30ae4a1
Add minimum 1 attention layer
Apr 15, 2024
7bd9c0a
More fixes
Apr 15, 2024
d5ac8e8
Delete mamba cache
Apr 15, 2024
60b49b5
Jamba padding to the left
Apr 16, 2024
c583fe8
Clean up
Apr 16, 2024
c951b7d
Add import
Apr 16, 2024
da6d0f2
Another clean up
Apr 16, 2024
eb79923
Align to main
Apr 16, 2024
919edba
Fix reduce
Apr 16, 2024
4668566
Another fix
Apr 16, 2024
11a0737
Black format for jamba
Apr 16, 2024
7e3415e
Formatting
Apr 16, 2024
adbd2ae
Formatting with format.sh
Apr 16, 2024
6daf2a2
Adding to docs and more
Apr 16, 2024
7ee927b
Add to readme
Apr 16, 2024
87fa299
Adding comments for prefill mamba
Apr 16, 2024
8bca3b6
Formatting
Apr 16, 2024
b421877
Remove mamba-ssm and conv1d from the build system requirements
Apr 17, 2024
d9c3319
Remove autoconfig for jamba
Apr 28, 2024
1c0fad8
Move get_attention_num_layers to model_config
Apr 28, 2024
e831cfc
Merge branch 'main' of https://github.com/vllm-project/vllm into jamb…
Apr 28, 2024
1033904
Fix in model config
Apr 28, 2024
2b93182
Formatting
Apr 28, 2024
b2f86f8
Add layers_block_type_support to model config
May 5, 2024
7061df7
Update Jamba to support changes from main
May 5, 2024
054faf1
Take Jamba config off since it's now in transformers
May 5, 2024
fb3fc83
Take jamba config off
May 5, 2024
6d8765d
Format
May 5, 2024
10896ae
Refactor the model runner a little, make it more readable and change
May 5, 2024
d1dc26f
rename release mamba to release seqlen agnostic
May 5, 2024
07c8cd2
Move requirements of mamba to its own requirements
May 5, 2024
5c11285
Remove mamba metadata since it's mamba specific
May 5, 2024
2bb3360
Align with master
May 5, 2024
a235c44
Change comment
May 5, 2024
af7a4ac
(1) implement contains_seqlen_agnostic_layers with use of self.get_nu…
tomeras91 May 6, 2024
988718e
Jamba official hf (#14)
tomeras91 May 6, 2024
49ce3df
fixes missed in merge with official Jamba HF format
tomeras91 May 7, 2024
4fa065f
Merge with main
zhuohan123 May 16, 2024
7add09a
fix merge error
zhuohan123 May 17, 2024
14fbab5
Fix bug where seqlen agnostic cache wasn't copied for non driver workers
May 20, 2024
e3dec15
WIP - encapsulate Jamba cache management inside the modeling file
Jun 25, 2024
92778c4
Cleanup
Jun 25, 2024
db36427
Typos and cleanup
Jun 26, 2024
7f6edfc
Another typo
Jun 26, 2024
ee5f058
Keep the finished requests ids in the scheduler after each step…
Jun 26, 2024
2d42367
Cleanup
Jun 26, 2024
6a6378c
clean up requests after profile
Jun 26, 2024
1a8e2f9
Update mamba requirements
Jun 26, 2024
eb89987
Renaming
Jun 26, 2024
1cb8c1c
Renaming
Jun 26, 2024
feca5d5
Rename and docs
Jun 26, 2024
72c31cc
Format
Jun 26, 2024
ddeb689
Add mamba to Dockerfile
Jun 26, 2024
5d5a3be
Mamba disable prompt batching
Jun 26, 2024
85715fe
Format
Jun 26, 2024
84aa88f
Merge branch 'gh-main' into jamba-support-pr
Jun 26, 2024
8c6d82d
Fix jamba bug
Jun 26, 2024
628eec7
Renaming
Jun 26, 2024
30030ce
WIP - Merge with main adaptations
Jun 26, 2024
3fba9bc
Fix
Jun 26, 2024
45f3d96
Formatting
Jun 26, 2024
794f1c3
Deploy the finished request ids inside the model inputs instead of the worker
Jun 27, 2024
33eb405
fix
Jun 27, 2024
25c03e7
Renaming
Jun 27, 2024
94d40a8
Format
Jun 27, 2024
976166f
Typing and format
Jun 27, 2024
8181821
Cleanup
Jun 27, 2024
4fdc35b
Remove requirements-common and cuda from requirements-mamba
Jun 27, 2024
aadeca2
Fix
Jun 27, 2024
fee775e
set finished requests ids as none on default
Jun 27, 2024
668f3d9
get attr to get num hidden layers
Jun 27, 2024
10a44dc
Add jamba test
Jun 27, 2024
cd9ba35
Ignore jamba test in cpu
Jun 27, 2024
6df4f69
Cleanup
Jun 27, 2024
75dd84e
Format and rename
Jun 27, 2024
577f678
Format
Jun 27, 2024
7bb332e
change num_layers to num_attention_layers and add comment
Jun 29, 2024
c051758
Extended the finished requests ids comment
Jun 29, 2024
b6dc237
Format and make the jamba code more readable, adding comments and
Jun 29, 2024
24b4bf2
Merge branch 'gh-main' into jamba-support-pr
Jun 29, 2024
b0b0836
Format
Jun 30, 2024
e52e4d7
Resolve conflicts and format
Jun 30, 2024
b4d49e0
Add finished requests ids to the prepare model spec decoding
Jun 30, 2024
68e27de
Format
Jun 30, 2024
670ff3a
Test cleanup
Jun 30, 2024
b7e31e3
Add message to test
Jun 30, 2024
571f63d
Add docstring in vllm/config.py
Jul 1, 2024
49da326
rename flush to get_and_reset
Jul 1, 2024
688732e
Add comments
Jul 1, 2024
4a6b170
Change to private and check finished through all of the queue
Jul 1, 2024
2047a91
CI
Jul 1, 2024
f2c407f
Merge branch 'gh-main' into jamba-support-pr
Jul 1, 2024
5d932a4
Pipeline Parallelism
andoorve Jul 1, 2024
3c15001
Make scheduler use pipeline parallel size directly
andoorve Jul 1, 2024
1ff2cdb
Change ABC defaults for prepare_model_input
andoorve Jul 1, 2024
548f4e8
Add basic comm ops tests with TP and PP.
andoorve Jul 1, 2024
5a4b323
Fix phi3v for test
andoorve Jul 1, 2024
c92257c
Address Nick nits and fix CUDAGraph correctness
andoorve Jul 2, 2024
60bb1a7
Merge branch 'pipeline-parallel' into jamba-support-pr
Jul 2, 2024
2ea2b80
Merge branch 'gh-main' into jamba-support-pr
Jul 2, 2024
10d8f3c
Formatting and fixing llm engine
Jul 2, 2024
1331a8f
Align with main and format
Jul 2, 2024
21c92b4
Fix bug
Jul 2, 2024
726ccad
Format
Jul 2, 2024
4b6a491
Add intermediate tensors
Jul 2, 2024
da5d94a
Format
Jul 2, 2024
2 changes: 1 addition & 1 deletion .buildkite/run-cpu-test.sh
@@ -23,4 +23,4 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
docker exec cpu-test bash -c "cd tests;
pip install pytest Pillow protobuf
cd ../
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
23 changes: 23 additions & 0 deletions Dockerfile
@@ -43,6 +43,10 @@ COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-cuda.txt

COPY requirements-mamba.txt requirements-mamba.txt
RUN python3 -m pip install packaging
RUN python3 -m pip install -r requirements-mamba.txt

# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
@@ -123,6 +127,21 @@ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-dev.txt

#################### DEV IMAGE ####################
#################### MAMBA Build IMAGE ####################
FROM dev as mamba-builder
# max jobs used for build
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}

WORKDIR /usr/src/mamba

COPY requirements-mamba.txt requirements-mamba.txt

# Download the wheel or build it if a pre-compiled release doesn't exist
RUN pip --verbose wheel -r requirements-mamba.txt \
--no-build-isolation --no-deps --no-cache-dir

#################### MAMBA Build IMAGE ####################

#################### vLLM installation IMAGE ####################
# image with vLLM installed
@@ -143,6 +162,10 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose

RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
#################### vLLM installation IMAGE ####################


4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
@@ -83,6 +83,10 @@ Alongside each architecture, we include some popular models that use it.
    - Jais
    - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
    -
  * - :code:`JambaForCausalLM`
    - Jamba
    - :code:`ai21labs/Jamba-v0.1`, etc.
    - ✅︎
  * - :code:`LlamaForCausalLM`
    - LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
    - :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
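
With the docs and registry entries in place, Jamba loads through the standard vLLM entry points. A minimal offline-inference sketch; the tensor_parallel_size value is illustrative only (ai21labs/Jamba-v0.1 is large, so size it to your hardware or use the tiny random checkpoint from the tests):

from vllm import LLM, SamplingParams

# Illustrative only: pick a tensor_parallel_size that matches your GPUs.
llm = LLM(model="ai21labs/Jamba-v0.1", tensor_parallel_size=4)
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["The Mamba state-space layer is"], sampling_params)
print(outputs[0].outputs[0].text)
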
3 changes: 3 additions & 0 deletions requirements-mamba.txt
@@ -0,0 +1,3 @@
# Mamba dependencies
mamba-ssm>=1.2.2
causal-conv1d>=1.2.0
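
Keeping these kernels out of the core requirements means a vLLM install without them should still import cleanly. A minimal sketch of the optional-dependency pattern (an illustration of the idea, not vLLM's actual gating code):

# Probe for the optional Mamba kernels once, and only raise when a
# Mamba/Jamba model is actually requested.
try:
    import causal_conv1d  # noqa: F401
    import mamba_ssm  # noqa: F401
    HAS_MAMBA_KERNELS = True
except ImportError:
    HAS_MAMBA_KERNELS = False


def require_mamba_kernels() -> None:
    if not HAS_MAMBA_KERNELS:
        raise ImportError("Install mamba-ssm>=1.2.2 and causal-conv1d>=1.2.0 "
                          "(see requirements-mamba.txt) to run Mamba/Jamba models.")
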
46 changes: 46 additions & 0 deletions tests/models/test_jamba.py
@@ -0,0 +1,46 @@
import pytest

MODELS = ["ai21labs/Jamba-tiny-random"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    # To pass the small model tests, we need full precision.
    assert dtype == "float"

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    with vllm_runner(model, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    with vllm_runner(model, dtype=dtype) as vllm_model:
        # This test is for verifying whether the model's extra_repr
        # can be printed correctly.
        print(vllm_model.model.llm_engine.model_executor.driver_worker.
              model_runner.model)
28 changes: 27 additions & 1 deletion vllm/config.py
@@ -349,9 +349,35 @@ def get_num_attention_heads(self,
        return num_heads // parallel_config.tensor_parallel_size

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        total_num_hidden_layers = self.hf_text_config.num_hidden_layers
        total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
return total_num_hidden_layers // parallel_config.pipeline_parallel_size

def contains_seqlen_agnostic_layers(
            self, parallel_config: "ParallelConfig") -> bool:
        return self.get_num_seqlen_agnostic_layers(parallel_config) > 0

    def get_layers_block_type(self,
                              parallel_config: "ParallelConfig") -> List[str]:
        num_layers = self.get_num_layers(parallel_config)
        # Transformers supports layers_block_type @property
        return getattr(self.hf_config, "layers_block_type",
                       ["attention"] * num_layers)

    def get_num_attention_layers(self,
                                 parallel_config: "ParallelConfig") -> int:
        return len([
            t for t in self.get_layers_block_type(parallel_config)
            if t == "attention"
        ])

    def get_num_seqlen_agnostic_layers(
            self, parallel_config: "ParallelConfig") -> int:
        return len([
            t for t in self.get_layers_block_type(parallel_config)
            if t != "attention"
        ])


class CacheConfig:
    """Configuration for the KV cache.
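
The helpers added to ModelConfig above only count entries in the layers_block_type list. A toy sketch of that counting logic with a hypothetical Jamba-like layout (the real list comes from the model's Hugging Face config):

# Hypothetical 32-layer layout: one attention layer per seven Mamba layers.
layers_block_type = (["mamba"] * 7 + ["attention"]) * 4

num_attention_layers = sum(1 for t in layers_block_type if t == "attention")
num_seqlen_agnostic_layers = sum(1 for t in layers_block_type if t != "attention")

assert num_attention_layers == 4        # attention layers (KV cache)
assert num_seqlen_agnostic_layers == 28  # Mamba layers (per-request state)
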
13 changes: 12 additions & 1 deletion vllm/core/scheduler.py
@@ -290,7 +290,8 @@ def __init__(
        # Sequence groups in the SWAPPED state.
        # Contain decode requests that are swapped out.
        self.swapped: Deque[SequenceGroup] = deque()

        # Sequence groups finished since last step iter.
        self.finished_requests_ids: List[str] = list()
        # Time at previous scheduling step
        self.prev_time = 0.0
        # Did we schedule a prompt at previous step?
@@ -364,6 +365,12 @@ def has_unfinished_seqs(self) -> bool:
    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)

    def flush_finished_requests_ids(self) -> List[str]:
"""Flushes the list of request ids of previously finished seq_groups."""
finished_requests_ids = self.finished_requests_ids
self.finished_requests_ids = []
return finished_requests_ids

def _schedule_running(
self,
running_queue: deque,
@@ -1027,6 +1034,10 @@ def free_seq(self, seq: Sequence) -> None:
        self.block_manager.free(seq)

    def free_finished_seq_groups(self) -> None:
        self.finished_requests_ids += [
            seq_group.request_id for seq_group in self.running
            if seq_group.is_finished()
        ]
        self.running = deque(seq_group for seq_group in self.running
                             if not seq_group.is_finished())

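
flush_finished_requests_ids follows an accumulate-then-hand-off pattern, so finished request ids reach the model exactly once per engine step and the Jamba model can release its Mamba cache entries for those requests. A standalone sketch of the pattern (a toy class, not the vLLM Scheduler itself):

from typing import List


class FinishedIdsTracker:
    """Toy version of the scheduler's finished-request bookkeeping."""

    def __init__(self) -> None:
        self._finished_requests_ids: List[str] = []

    def mark_finished(self, request_id: str) -> None:
        self._finished_requests_ids.append(request_id)

    def flush_finished_requests_ids(self) -> List[str]:
        # Hand the accumulated ids out and reset, so each id is seen once.
        finished = self._finished_requests_ids
        self._finished_requests_ids = []
        return finished


tracker = FinishedIdsTracker()
tracker.mark_finished("req-0")
tracker.mark_finished("req-1")
assert tracker.flush_finished_requests_ids() == ["req-0", "req-1"]
assert tracker.flush_finished_requests_ids() == []
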
3 changes: 2 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -232,7 +232,8 @@ async def step_async(
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
            )
                finished_requests_ids=self.scheduler.
                flush_finished_requests_ids())
            output = await self.model_executor.execute_model_async(
                execute_model_req)
        else:
3 changes: 2 additions & 1 deletion vllm/engine/llm_engine.py
@@ -802,7 +802,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                blocks_to_copy=scheduler_outputs.blocks_to_copy,
                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                running_queue_size=scheduler_outputs.running_queue_size,
            )
                finished_requests_ids=self.scheduler.
                flush_finished_requests_ids())
            output = self.model_executor.execute_model(
                execute_model_req=execute_model_req)
        else:
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -61,6 +61,7 @@
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM")
}

_EMBEDDING_MODELS = {
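
The registry edited above maps an architecture name from the model's config to a (module, class) pair that is imported lazily. A simplified sketch of how such a lookup can work (a standalone illustration with a hypothetical "my_models" package; vLLM's own ModelRegistry has more machinery around this):

import importlib
from typing import Dict, Tuple, Type

# Architecture name -> (module file, class name), mirroring the entry above.
_MODELS: Dict[str, Tuple[str, str]] = {
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
}


def load_model_cls(arch: str, package: str = "my_models") -> Type:
    # "my_models" is a hypothetical package name used for illustration.
    module_name, cls_name = _MODELS[arch]
    module = importlib.import_module(f"{package}.{module_name}")
    return getattr(module, cls_name)
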