diff --git a/create_test_data.py b/create_test_data.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py
index a0e83229ab..7554ca66be 100644
--- a/xinference/core/supervisor.py
+++ b/xinference/core/supervisor.py
@@ -37,6 +37,7 @@
     from ..model.embedding import EmbeddingModelSpec
     from ..model.image import ImageModelFamilyV1
     from ..model.llm import LLMFamilyV1
+    from ..model.multimodal import LVLMFamilyV1
     from ..model.rerank import RerankModelSpec
     from .worker import WorkerActor
 
@@ -222,6 +223,25 @@ def _to_image_model_reg(
                 "is_builtin": is_builtin,
             }
 
+    def _to_multimodal_reg(
+        self, model_family: "LVLMFamilyV1", is_builtin: bool
+    ) -> Dict[str, Any]:
+        from ..model.llm import get_cache_status
+
+        if self.is_local_deployment():
+            specs = []
+            # TODO: does not work when the supervisor and worker are running on separate nodes.
+            for spec in model_family.model_specs:
+                cache_status = get_cache_status(model_family, spec)
+                specs.append({**spec.dict(), "cache_status": cache_status})
+            return {
+                **model_family.dict(),
+                "is_builtin": is_builtin,
+                "model_specs": specs,
+            }
+        else:
+            return {**model_family.dict(), "is_builtin": is_builtin}
+
     @log_sync(logger=logger)
     def list_model_registrations(
         self, model_type: str, detailed: bool = False
@@ -302,6 +322,18 @@ def sort_helper(item):
                         {"model_name": model_spec.model_name, "is_builtin": False}
                     )
 
+            ret.sort(key=sort_helper)
+            return ret
+        elif model_type == "multimodal":
+            from ..model.multimodal import BUILTIN_LVLM_FAMILIES
+
+            ret = []
+            for family in BUILTIN_LVLM_FAMILIES:
+                if detailed:
+                    ret.append(self._to_multimodal_reg(family, True))
+                else:
+                    ret.append({"model_name": family.model_name, "is_builtin": True})
+
             ret.sort(key=sort_helper)
             return ret
         else:
@@ -342,6 +374,13 @@ def get_model_registration(self, model_type: str, model_name: str) -> Any:
                 if f.model_name == model_name:
                     return f
             raise ValueError(f"Model {model_name} not found")
+        elif model_type == "multimodal":
+            from ..model.multimodal import BUILTIN_LVLM_FAMILIES
+
+            for f in BUILTIN_LVLM_FAMILIES:
+                if f.model_name == model_name:
+                    return f
+            raise ValueError(f"Model {model_name} not found")
         else:
             raise ValueError(f"Unsupported model type: {model_type}")
 
diff --git a/xinference/model/core.py b/xinference/model/core.py
index 9414c504e4..bcc465247c 100644
--- a/xinference/model/core.py
+++ b/xinference/model/core.py
@@ -44,6 +44,7 @@ def create_model_instance(
     from .embedding.core import create_embedding_model_instance
     from .image.core import create_image_model_instance
     from .llm.core import create_llm_model_instance
+    from .multimodal.core import create_multimodal_model_instance
     from .rerank.core import create_rerank_model_instance
 
     if model_type == "LLM":
@@ -74,5 +75,10 @@ def create_model_instance(
         return create_rerank_model_instance(
             subpool_addr, devices, model_uid, model_name, **kwargs
         )
+    elif model_type == "multimodal":
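+        # The multimodal implementations enable trust_remote_code themselves
+        # (see qwen_vl.py), so the flag is dropped here rather than forwarded.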
+        kwargs.pop("trust_remote_code", None)
+        return create_multimodal_model_instance(
+            subpool_addr, devices, model_uid, model_name, **kwargs
+        )
     else:
         raise ValueError(f"Unsupported model type: {model_type}.")
diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py
index b943098586..db75c721d0 100644
--- a/xinference/model/llm/pytorch/core.py
+++ b/xinference/model/llm/pytorch/core.py
@@ -29,6 +29,7 @@
     PytorchGenerateConfig,
     PytorchModelConfig,
 )
+from ...utils import select_device
 from ..core import LLM
 from ..llm_family import LLMFamilyV1, LLMSpecV1
 from ..utils import ChatModelMixin
@@ -122,7 +123,7 @@ def load(self):
         quantization = self.quantization
         num_gpus = len(cuda_visible_devices) if cuda_visible_devices_env != "-1" else 0
         device = self._pytorch_model_config.get("device", "auto")
-        self._pytorch_model_config["device"] = self._select_device(device)
+        self._pytorch_model_config["device"] = select_device(device)
         self._device = self._pytorch_model_config["device"]
 
         if self._device == "cpu":
@@ -185,33 +186,6 @@ def load(self):
             self._model.to(self._device)
         logger.debug(f"Model Memory: {self._model.get_memory_footprint()}")
 
-    def _select_device(self, device: str) -> str:
-        try:
-            import torch
-        except ImportError:
-            raise ImportError(
-                f"Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
-            )
-
-        if device == "auto":
-            # When env CUDA_VISIBLE_DEVICES=-1, torch.cuda.is_available() return False
-            if torch.cuda.is_available():
-                return "cuda"
-            elif torch.backends.mps.is_available():
-                return "mps"
-            return "cpu"
-        elif device == "cuda":
-            if not torch.cuda.is_available():
-                raise ValueError("cuda is unavailable in your environment")
-        elif device == "mps":
-            if not torch.backends.mps.is_available():
-                raise ValueError("mps is unavailable in your environment")
-        elif device == "cpu":
-            pass
-        else:
-            raise ValueError(f"Device {device} is not supported in temporary")
-        return device
-
     @classmethod
     def match(
         cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
diff --git a/xinference/model/llm/pytorch/spec_model.py b/xinference/model/llm/pytorch/spec_model.py
index e438bbb264..a66f6fbfc1 100644
--- a/xinference/model/llm/pytorch/spec_model.py
+++ b/xinference/model/llm/pytorch/spec_model.py
@@ -17,6 +17,7 @@
 from typing import Iterator, List, Optional, Union
 
 from ....types import Completion, CompletionChunk, Embedding
+from ...utils import select_device
 from .. import LLMFamilyV1, LLMSpecV1
 from .core import PytorchChatModel, PytorchGenerateConfig, PytorchModelConfig
 
@@ -85,7 +86,7 @@ def load(self):
 
         num_gpus = len(cuda_visible_devices) if cuda_visible_devices_env != "-1" else 0
         device = self._pytorch_model_config.get("device", "auto")
-        self._pytorch_model_config["device"] = self._select_device(device)
+        self._pytorch_model_config["device"] = select_device(device)
         self._device = self._pytorch_model_config["device"]
 
         if self._device == "cpu":
diff --git a/xinference/model/multimodal/__init__.py b/xinference/model/multimodal/__init__.py
new file mode 100644
index 0000000000..bae4627739
--- /dev/null
+++ b/xinference/model/multimodal/__init__.py
@@ -0,0 +1,45 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+import json
+import os
+
+from .core import (
+    BUILTIN_LVLM_FAMILIES,
+    BUILTIN_MODELSCOPE_LVLM_FAMILIES,
+    MODEL_CLASSES,
+    MODEL_NAME_TO_REVISION,
+    LVLMFamilyV1,
+    LVLMPromptStyleV1,
+)
+from .qwen_vl import QwenVLChat
+
+MODEL_CLASSES.append(QwenVLChat)
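+# Registering QwenVLChat here lets `match_cls` in core.py dispatch "qwen-vl-chat"
+# requests to this implementation via its `match` classmethod.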
+
+
+def _install():
+    json_path = os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "model_spec.json"
+    )
+    for json_obj in json.load(codecs.open(json_path, "r", encoding="utf-8")):
+        model_family = LVLMFamilyV1.parse_obj(json_obj)
+        BUILTIN_LVLM_FAMILIES.append(model_family)
+        for model_spec in model_family.model_specs:
+            MODEL_NAME_TO_REVISION[model_family.model_name].append(
+                model_spec.model_revision
+            )
+
+
+_install()
diff --git a/xinference/model/multimodal/core.py b/xinference/model/multimodal/core.py
new file mode 100644
index 0000000000..678c8583b1
--- /dev/null
+++ b/xinference/model/multimodal/core.py
@@ -0,0 +1,460 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import logging
+import os
+from abc import abstractmethod
+from collections import defaultdict
+from typing import Dict, Iterator, List, Literal, Optional, Tuple, Type, Union
+
+from pydantic import BaseModel, validator
+
+from ...constants import XINFERENCE_CACHE_DIR
+from ...core.utils import parse_replica_model_uid
+from ...types import ChatCompletion, ChatCompletionChunk
+from ..core import ModelDescription
+from ..utils import (
+    download_from_modelscope,
+    is_model_cached,
+    retry_download,
+    symlink_local_file,
+    valid_model_revision,
+)
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_CONTEXT_LENGTH = 2048
+# Used to check whether the model is cached.
+# Initialized when registering all the builtin models.
+MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
+
+
+class LVLMSpecV1(BaseModel):
+    model_format: Literal["pytorch", "gptq"]
+    # `str` must come before `int` in the Union: with `int` first, pydantic would coerce "1_8" into 18
+    model_size_in_billions: Union[str, int]
+    quantizations: List[str]
+    model_id: str
+    model_hub: str = "huggingface"
+    model_uri: Optional[str]
+    model_revision: Optional[str]
+
+    @validator("model_size_in_billions", pre=False)
+    def validate_model_size_with_radix(cls, v: object) -> object:
+        if isinstance(v, str):
+            if "_" in v:
+                # e.g. "1_8" (i.e. 1.8 billion) is kept as a string,
+                # because int("1_8") would silently return 18
+                return v
+            else:
+                return int(v)
+        return v
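+    # Net effect (sketch): a plain size such as 7 ends up as int 7, while the radix
+    # form "1_8" (i.e. 1.8 billion) is preserved as the string "1_8".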
+
+
+class LVLMPromptStyleV1(BaseModel):
+    style_name: str
+    system_prompt: str = ""
+    roles: List[str]
+
+
+class LVLMFamilyV1(BaseModel):
+    version: Literal[1]
+    context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH
+    model_name: str
+    model_lang: List[str]
+    model_ability: List[Literal["chat"]]
+    model_description: Optional[str]
+    model_specs: List["LVLMSpecV1"]
+    prompt_style: Optional["LVLMPromptStyleV1"]
+
+
+class LVLMDescription(ModelDescription):
+    def __init__(
+        self,
+        address: Optional[str],
+        devices: Optional[List[str]],
+        model_family: "LVLMFamilyV1",
+        model_spec: "LVLMSpecV1",
+        quantization: Optional[str],
+    ):
+        super().__init__(address, devices)
+        self._model_family = model_family
+        self._model_spec = model_spec
+        self._quantization = quantization
+
+    def to_dict(self):
+        return {
+            "model_type": "LVLM",
+            "address": self.address,
+            "accelerators": self.devices,
+            "model_name": self._model_family.model_name,
+            "model_lang": self._model_family.model_lang,
+            "model_ability": self._model_family.model_ability,
+            "model_description": self._model_family.model_description,
+            "model_format": self._model_spec.model_format,
+            "model_size_in_billions": self._model_spec.model_size_in_billions,
+            "quantization": self._quantization,
+            "model_hub": self._model_spec.model_hub,
+            "revision": self._model_spec.model_revision,
+            "context_length": self._model_family.context_length,
+        }
+
+
+class LVLM(abc.ABC):
+    def __init__(
+        self,
+        replica_model_uid: str,
+        model_family: "LVLMFamilyV1",
+        model_spec: "LVLMSpecV1",
+        quantization: str,
+        model_path: str,
+        kwargs: Dict,
+    ):
+        self.model_uid, self.replica, self.rep_id = parse_replica_model_uid(
+            replica_model_uid
+        )
+        self.model_family = model_family
+        self.model_spec = model_spec
+        self.quantization = quantization
+        self.model_path = model_path
+        self.kwargs = kwargs
+        logger.info("Init model %s with kwargs: %s", self.model_uid, kwargs)
+
+    @abstractmethod
+    def load(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def chat(
+        self,
+        prompt: str,
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[Dict]] = None,
+        generate_config: Optional[Dict] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        raise NotImplementedError
+
+    @classmethod
+    def match(
+        cls, model_family: "LVLMFamilyV1", model_spec: "LVLMSpecV1", quantization: str
+    ) -> bool:
+        raise NotImplementedError
+
+
+BUILTIN_LVLM_FAMILIES: List["LVLMFamilyV1"] = []
+BUILTIN_MODELSCOPE_LVLM_FAMILIES: List["LVLMFamilyV1"] = []
+
+
+def match_multimodal(
+    model_name: str,
+    model_format: Optional[str] = None,
+    model_size_in_billions: Optional[int] = None,
+    quantization: Optional[str] = None,
+) -> Optional[Tuple[LVLMFamilyV1, LVLMSpecV1, str]]:
+    """
+    Find a multimodal family, spec, and quantization that satisfy the given criteria.
+    """
+
+    def _match_quantization(q: Union[str, None], quantizations: List[str]):
+        # Quantization names may mix uppercase and lowercase letters,
+        # so match them case-insensitively.
+        if q is None:
+            return q
+        for quant in quantizations:
+            if q.lower() == quant.lower():
+                return quant
+
+    def _apply_format_to_model_id(spec: LVLMSpecV1, q: str) -> LVLMSpecV1:
+        # Different quantized versions of some models use different model ids;
+        # when the model id contains a `{quantization}` placeholder, fill it in here.
+        if "{" in spec.model_id:
+            spec.model_id = spec.model_id.format(quantization=q)
+        return spec
+
+    if download_from_modelscope():
+        all_families = BUILTIN_MODELSCOPE_LVLM_FAMILIES + BUILTIN_LVLM_FAMILIES
+    else:
+        all_families = BUILTIN_LVLM_FAMILIES
+
+    for family in all_families:
+        if model_name != family.model_name:
+            continue
+        for spec in family.model_specs:
+            matched_quantization = _match_quantization(quantization, spec.quantizations)
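+            # Skip this spec when any criterion the caller specified does not match;
+            # `and` binds tighter than `or`, so each pair below is an independent check.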
+            if (
+                model_format
+                and model_format != spec.model_format
+                or model_size_in_billions
+                and model_size_in_billions != spec.model_size_in_billions
+                or quantization
+                and matched_quantization is None
+            ):
+                continue
+            if quantization:
+                return (
+                    family,
+                    _apply_format_to_model_id(spec, matched_quantization),
+                    matched_quantization,
+                )
+            else:
+                return family, _apply_format_to_model_id(spec, "none"), "none"
+    return None
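+
+# A minimal usage sketch (hypothetical arguments):
+#   matched = match_multimodal("qwen-vl-chat", model_format="pytorch")
+#   if matched is not None:
+#       family, spec, quantization = matched  # quantization falls back to "none" here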
+
+
+def create_multimodal_model_instance(
+    subpool_addr: str,
+    devices: List[str],
+    model_uid: str,
+    model_name: str,
+    model_format: Optional[str] = None,
+    model_size_in_billions: Optional[int] = None,
+    quantization: Optional[str] = None,
+    **kwargs,
+) -> Tuple[LVLM, LVLMDescription]:
+    match_result = match_multimodal(
+        model_name,
+        model_format,
+        model_size_in_billions,
+        quantization,
+    )
+    if not match_result:
+        raise ValueError(
+            f"Model not found, name: {model_name}, format: {model_format},"
+            f" size: {model_size_in_billions}, quantization: {quantization}"
+        )
+    model_family, model_spec, quantization = match_result
+
+    assert quantization is not None
+    save_path = cache(model_family, model_spec, quantization)
+
+    cls = match_cls(model_family, model_spec, quantization)
+    logger.debug(f"Launching {model_uid} with {cls.__name__}")
+
+    model = cls(model_uid, model_family, model_spec, quantization, save_path, kwargs)
+    return model, LVLMDescription(
+        subpool_addr, devices, model_family, model_spec, quantization
+    )
+
+
+MODEL_CLASSES: List[Type[LVLM]] = []
+
+
+def match_cls(
+    model_family: LVLMFamilyV1, model_spec: "LVLMSpecV1", quantization: str
+) -> Type[LVLM]:
+    """
+    Find a multimodal implementation for the given multimodal family and spec.
+    """
+    for cls in MODEL_CLASSES:
+        if cls.match(model_family, model_spec, quantization):
+            return cls
+    raise Exception(f"Model {model_family.model_name} is not supported")
+
+
+def _get_cache_dir(
+    model_family: LVLMFamilyV1,
+    model_spec: "LVLMSpecV1",
+    create_if_not_exist=True,
+):
+    cache_dir_name = (
+        f"{model_family.model_name}-{model_spec.model_format}"
+        f"-{model_spec.model_size_in_billions}b"
+    )
+    cache_dir = os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, cache_dir_name))
+    if create_if_not_exist and not os.path.exists(cache_dir):
+        os.makedirs(cache_dir, exist_ok=True)
+    return cache_dir
+
+
+def _get_meta_path(
+    cache_dir: str,
+    model_format: str,
+    model_hub: str,
+    quantization: Optional[str] = None,
+):
+    if model_format == "pytorch":
+        if model_hub == "huggingface":
+            return os.path.join(cache_dir, "__valid_download")
+        else:
+            return os.path.join(cache_dir, f"__valid_download_{model_hub}")
+    elif model_format in ["ggmlv3", "ggufv2", "gptq"]:
+        assert quantization is not None
+        if model_hub == "huggingface":
+            return os.path.join(cache_dir, f"__valid_download_{quantization}")
+        else:
+            return os.path.join(
+                cache_dir, f"__valid_download_{model_hub}_{quantization}"
+            )
+    else:
+        raise ValueError(f"Unsupported format: {model_format}")
+
+
+def _skip_download(
+    cache_dir: str,
+    model_format: str,
+    model_hub: str,
+    model_revision: Optional[str],
+    quantization: Optional[str] = None,
+) -> bool:
+    if model_format == "pytorch":
+        model_hub_to_meta_path = {
+            "huggingface": _get_meta_path(
+                cache_dir, model_format, "huggingface", quantization
+            ),
+            "modelscope": _get_meta_path(
+                cache_dir, model_format, "modelscope", quantization
+            ),
+        }
+        if valid_model_revision(model_hub_to_meta_path[model_hub], model_revision):
+            logger.info(f"Cache {cache_dir} exists")
+            return True
+        else:
+            for hub, meta_path in model_hub_to_meta_path.items():
+                if hub != model_hub and os.path.exists(meta_path):
+                    # PyTorch models from modelscope can also be loaded by transformers.
+                    logger.warning(f"Cache {cache_dir} exists, but it was from {hub}")
+                    return True
+            return False
+    else:
+        raise ValueError(f"Unsupported format: {model_format}")
+
+
+def _generate_meta_file(
+    meta_path: str,
+    model_family: "LVLMFamilyV1",
+    model_spec: "LVLMSpecV1",
+    quantization: Optional[str] = None,
+):
+    assert not valid_model_revision(
+        meta_path, model_spec.model_revision
+    ), f"meta file {meta_path} should not be valid"
+    with open(meta_path, "w") as f:
+        import json
+
+        desc = LVLMDescription(None, None, model_family, model_spec, quantization)
+        json.dump(desc.to_dict(), f)
+
+
+def cache_from_modelscope(
+    model_family: LVLMFamilyV1,
+    model_spec: "LVLMSpecV1",
+    quantization: Optional[str] = None,
+) -> str:
+    """
+    Cache model from Modelscope. Return the cache directory.
+    """
+    from modelscope.hub.snapshot_download import snapshot_download
+
+    cache_dir = _get_cache_dir(model_family, model_spec)
+    if _skip_download(
+        cache_dir,
+        model_spec.model_format,
+        model_spec.model_hub,
+        model_spec.model_revision,
+        quantization,
+    ):
+        return cache_dir
+
+    if model_spec.model_format in ["pytorch", "gptq"]:
+        download_dir = retry_download(
+            snapshot_download,
+            model_family.model_name,
+            {
+                "model_size": model_spec.model_size_in_billions,
+                "model_format": model_spec.model_format,
+            },
+            model_spec.model_id,
+            revision=model_spec.model_revision,
+        )
+        for subdir, dirs, files in os.walk(download_dir):
+            for file in files:
+                relpath = os.path.relpath(os.path.join(subdir, file), download_dir)
+                symlink_local_file(os.path.join(subdir, file), cache_dir, relpath)
+    else:
+        raise ValueError(f"Unsupported format: {model_spec.model_format}")
+
+    meta_path = _get_meta_path(
+        cache_dir, model_spec.model_format, model_spec.model_hub, quantization
+    )
+    _generate_meta_file(meta_path, model_family, model_spec, quantization)
+
+    return cache_dir
+
+
+def cache_from_huggingface(
+    model_family: LVLMFamilyV1,
+    model_spec: "LVLMSpecV1",
+    quantization: Optional[str] = None,
+) -> str:
+    """
+    Cache model from Hugging Face. Return the cache directory.
+    """
+    import huggingface_hub
+
+    cache_dir = _get_cache_dir(model_family, model_spec)
+    if _skip_download(
+        cache_dir,
+        model_spec.model_format,
+        model_spec.model_hub,
+        model_spec.model_revision,
+        quantization,
+    ):
+        return cache_dir
+
+    if model_spec.model_format in ["pytorch"]:
+        assert isinstance(model_spec, LVLMSpecV1)
+        retry_download(
+            huggingface_hub.snapshot_download,
+            model_family.model_name,
+            {
+                "model_size": model_spec.model_size_in_billions,
+                "model_format": model_spec.model_format,
+            },
+            model_spec.model_id,
+            revision=model_spec.model_revision,
+            local_dir=cache_dir,
+            local_dir_use_symlinks=True,
+        )
+    else:
+        raise ValueError(f"Unsupported model format: {model_spec.model_format}")
+
+    meta_path = _get_meta_path(
+        cache_dir, model_spec.model_format, model_spec.model_hub, quantization
+    )
+    _generate_meta_file(meta_path, model_family, model_spec, quantization)
+
+    return cache_dir
+
+
+def cache(
+    model_family: LVLMFamilyV1,
+    model_spec: "LVLMSpecV1",
+    quantization: Optional[str] = None,
+) -> str:
+    if model_spec.model_hub == "huggingface":
+        logger.info(f"Caching from Hugging Face: {model_spec.model_id}")
+        return cache_from_huggingface(model_family, model_spec, quantization)
+    elif model_spec.model_hub == "modelscope":
+        logger.info(f"Caching from Modelscope: {model_spec.model_id}")
+        return cache_from_modelscope(model_family, model_spec, quantization)
+    else:
+        raise ValueError(f"Unknown model hub: {model_spec.model_hub}")
+
+
+def get_cache_status(
+    model_spec: LVLMSpecV1,
+) -> bool:
+    return is_model_cached(model_spec, MODEL_NAME_TO_REVISION)
diff --git a/xinference/model/multimodal/model_spec.json b/xinference/model/multimodal/model_spec.json
new file mode 100644
index 0000000000..07af7f2f19
--- /dev/null
+++ b/xinference/model/multimodal/model_spec.json
@@ -0,0 +1,34 @@
+[
+  {
+    "version": 1,
+    "context_length": 4096,
+    "model_name": "qwen-vl-chat",
+    "model_lang": [
+      "en",
+      "zh"
+    ],
+    "model_ability": [
+      "chat"
+    ],
+    "model_description": "Qwen-VL-Chat supports more flexible interaction, such as multiple image inputs, multi-round question answering, and creative capabilities.",
+    "model_specs": [
+      {
+        "model_format": "pytorch",
+        "model_size_in_billions": 7,
+        "quantizations": [
+          "none"
+        ],
+        "model_id": "Qwen/Qwen-VL-Chat",
+        "model_revision": "6665c780ade5ff3f08853b4262dcb9c8f9598d42"
+      }
+    ],
+    "prompt_style": {
+      "style_name": "QWEN",
+      "system_prompt": "You are a helpful assistant.",
+      "roles": [
+        "user",
+        "assistant"
+      ]
+    }
+  }
+]
diff --git a/xinference/model/multimodal/qwen_vl.py b/xinference/model/multimodal/qwen_vl.py
new file mode 100644
index 0000000000..55e29fe182
--- /dev/null
+++ b/xinference/model/multimodal/qwen_vl.py
@@ -0,0 +1,120 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import operator
+import time
+import uuid
+from typing import Dict, Iterator, List, Optional, Union
+
+from ...types import (
+    ChatCompletion,
+    ChatCompletionChoice,
+    ChatCompletionChunk,
+    CompletionUsage,
+)
+from ..utils import select_device
+from .core import LVLM, LVLMFamilyV1, LVLMSpecV1
+
+
+class QwenVLChat(LVLM):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._tokenizer = None
+        self._model = None
+
+    @classmethod
+    def match(
+        cls, model_family: "LVLMFamilyV1", model_spec: "LVLMSpecV1", quantization: str
+    ) -> bool:
+        if "qwen" in model_family.model_name:
+            return True
+        return False
+
+    def load(self):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+        from transformers.generation import GenerationConfig
+
+        device = self.kwargs.get("device", "auto")
+        device = select_device(device)
+
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+            code_revision=self.model_spec.model_revision,
+        )
+        self._model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            device_map=device,
+            trust_remote_code=True,
+            code_revision=self.model_spec.model_revision,
+        ).eval()
+        # Specify hyperparameters for generation
+        self._model.generation_config = GenerationConfig.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+            code_revision=self.model_spec.model_revision,
+        )
+
+    def _message_content_to_qwen(self, content) -> str:
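+        # Conversion sketch: an OpenAI-style content list such as
+        #   [{"type": "text", "text": "What is this?"},
+        #    {"type": "image_url", "image_url": {"url": "http://example.com/a.jpg"}}]
+        # becomes [{"image": "http://example.com/a.jpg", "type": "image"},
+        #          {"type": "text", "text": "What is this?"}]
+        # (sorted so images come first) before `from_list_format` renders it to a string.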
+        if not isinstance(content, str):
+            content = [
+                {"image": c["image_url"]["url"], "type": "image"}
+                if c.get("type") == "image_url"
+                else c
+                for c in content
+            ]
+            content = sorted(content, key=operator.itemgetter("type"))
+            return self._tokenizer.from_list_format(content)
+        return content
+
+    def chat(
+        self,
+        prompt: Union[str, List[Dict]],
+        system_prompt: Optional[str] = None,
+        chat_history: Optional[List[Dict]] = None,
+        generate_config: Optional[Dict] = None,
+    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
+        prompt = self._message_content_to_qwen(prompt)
+        # Convert OpenAI-style chat history into Qwen-VL's [query, response] pair format
+        qwen_history = []
+        query_to_response: List = []
+        for h in chat_history or []:
+            role = h["role"]
+            content = self._message_content_to_qwen(h["content"])
+            if len(query_to_response) == 0 and role == "user":
+                query_to_response.append(content)
+            if len(query_to_response) == 1 and role == "assistant":
+                query_to_response.append(content)
+            if len(query_to_response) == 2:
+                qwen_history.append(query_to_response)
+                query_to_response = []
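+        # qwen_history now holds completed [user, assistant] pairs; a trailing user
+        # message with no reply is dropped by the pairing loop above.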
+        response, history = self._model.chat(
+            self._tokenizer, query=prompt, history=qwen_history
+        )
+        return ChatCompletion(
+            id="chat" + str(uuid.uuid1()),
+            object="chat.completion",
+            created=int(time.time()),
+            model=self.model_uid,
+            choices=[
+                ChatCompletionChoice(
+                    index=0,
+                    message={"role": "assistant", "content": response},
+                    finish_reason="stop",
+                )
+            ],
+            usage=CompletionUsage(
+                prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
+            ),
+        )
diff --git a/xinference/model/multimodal/tests/__init__.py b/xinference/model/multimodal/tests/__init__.py
new file mode 100644
index 0000000000..37f6558d95
--- /dev/null
+++ b/xinference/model/multimodal/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/xinference/model/multimodal/tests/test_multimodal.py b/xinference/model/multimodal/tests/test_multimodal.py
new file mode 100644
index 0000000000..38317049b8
--- /dev/null
+++ b/xinference/model/multimodal/tests/test_multimodal.py
@@ -0,0 +1,80 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+
+@pytest.mark.skip(reason="Cost too many resources.")
+def test_restful_api_for_qwen_vl(setup):
+    endpoint, _ = setup
+    from ....client import Client
+
+    client = Client(endpoint)
+
+    model_uid = client.launch_model(
+        model_uid="my_controlnet",
+        model_name="qwen-vl-chat",
+        model_type="multimodal",
+        device="cpu",
+    )
+    model = client.get_model(model_uid)
+    assert model
+
+    # openai client
+    import openai
+
+    client = openai.Client(api_key="not empty", base_url=f"{endpoint}/v1")
+    completion = client.chat.completions.create(
+        model=model_uid,
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "What’s in this image?"},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+                        },
+                    },
+                ],
+            }
+        ],
+    )
+    assert "grass" in completion.choices[0].message.content
+    assert "tree" in completion.choices[0].message.content
+    assert "sky" in completion.choices[0].message.content
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "这是什么?"},
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+                    },
+                },
+            ],
+        }
+    ]
+    completion = client.chat.completions.create(model=model_uid, messages=messages)
+    assert "女" in completion.choices[0].message.content
+    assert "狗" in completion.choices[0].message.content
+    assert "沙滩" in completion.choices[0].message.content
+    messages.append(completion.choices[0].message.model_dump())
+    messages.append({"role": "user", "content": "框出图中击掌的位置"})
+    completion = client.chat.completions.create(model=model_uid, messages=messages)
+    assert "击掌" in completion.choices[0].message.content
+    assert "<ref>" in completion.choices[0].message.content
+    assert "<box>" in completion.choices[0].message.content
diff --git a/xinference/model/utils.py b/xinference/model/utils.py
index dafc25fd2a..490e1b5cb7 100644
--- a/xinference/model/utils.py
+++ b/xinference/model/utils.py
@@ -255,3 +255,31 @@ def _patched_resolve_trust_remote_code(*args, **kwargs):
             resolve_trust_remote_code.__code__ = (
                 _patched_resolve_trust_remote_code.__code__
             )
+
+
+def select_device(device):
+    try:
+        import torch
+    except ImportError:
+        raise ImportError(
+            f"Failed to import module 'torch'. Please make sure 'torch' is installed.\n\n"
+        )
+
+    if device == "auto":
+        # When env CUDA_VISIBLE_DEVICES=-1, torch.cuda.is_available() returns False
+        if torch.cuda.is_available():
+            return "cuda"
+        elif torch.backends.mps.is_available():
+            return "mps"
+        return "cpu"
+    elif device == "cuda":
+        if not torch.cuda.is_available():
+            raise ValueError("cuda is unavailable in your environment")
+    elif device == "mps":
+        if not torch.backends.mps.is_available():
+            raise ValueError("mps is unavailable in your environment")
+    elif device == "cpu":
+        pass
+    else:
+        raise ValueError(f"Device {device} is not supported in temporary")
+    return device
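+
+
+# Usage sketch: select_device("auto") prefers "cuda" when a GPU is visible, then "mps"
+# on Apple silicon, otherwise "cpu"; explicitly requesting an unavailable device raises
+# ValueError.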