Skip to content

Commit

Permalink
[CI/Build] Simplify model loading for HfRunner (vllm-project#5251)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored Jun 4, 2024
1 parent 27208be commit 9ba093b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
27 changes: 16 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import contextlib
import gc
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
LlavaConfig, LlavaForConditionalGeneration)
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoProcessor, AutoTokenizer, BatchEncoding)

from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
Expand Down Expand Up @@ -144,16 +145,12 @@ def example_long_prompts() -> List[str]:
"float": torch.float,
}

AutoModelForCausalLM.register(LlavaConfig, LlavaForConditionalGeneration)

_EMBEDDING_MODELS = [
"intfloat/e5-mistral-7b-instruct",
]
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)


class HfRunner:

def wrap_device(self, input: any):
def wrap_device(self, input: _T) -> _T:
if not is_cpu():
return input.to("cuda")
else:
Expand All @@ -163,13 +160,16 @@ def __init__(
self,
model_name: str,
dtype: str = "half",
*,
is_embedding_model: bool = False,
is_vision_model: bool = False,
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

self.model_name = model_name

if model_name in _EMBEDDING_MODELS:
if is_embedding_model:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = self.wrap_device(
Expand All @@ -178,8 +178,13 @@ def __init__(
device="cpu",
).to(dtype=torch_dtype))
else:
if is_vision_model:
auto_cls = AutoModelForVision2Seq
else:
auto_cls = AutoModelForCausalLM

self.model = self.wrap_device(
AutoModelForCausalLM.from_pretrained(
auto_cls.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_models(
model: str,
dtype: str,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_model = hf_runner(model, dtype=dtype, is_embedding_model=True)
hf_outputs = hf_model.encode(example_prompts)
del hf_model

Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
"""
model_id, vision_language_config = model_and_config

hf_model = hf_runner(model_id, dtype=dtype)
hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True)
hf_outputs = hf_model.generate_greedy(hf_image_prompts,
max_tokens,
images=hf_images)
Expand Down

0 comments on commit 9ba093b

Please sign in to comment.