Commit ab64442

fix

qinxuye committed Jan 19, 2025
1 parent 9e081f6

Showing 2 changed files with 15 additions and 5 deletions.

setup.cfg (2 additions, 2 deletions)

@@ -98,7 +98,7 @@ all =
     optimum
     outlines>=0.0.34
     sglang>=0.2.7 ; sys_platform=='linux'
-    mlx-lm ; sys_platform=='darwin' and platform_machine=='arm64'
+    mlx-lm>=0.21.1 ; sys_platform=='darwin' and platform_machine=='arm64'
     mlx-vlm>=0.1.11 ; sys_platform=='darwin' and platform_machine=='arm64'
     mlx-whisper ; sys_platform=='darwin' and platform_machine=='arm64'
     f5-tts-mlx ; sys_platform=='darwin' and platform_machine=='arm64'
@@ -182,7 +182,7 @@ sglang =
     vllm>=0.5.2 ; sys_platform=='linux'
     outlines>=0.0.34
 mlx =
-    mlx-lm
+    mlx-lm>=0.21.1
     mlx-vlm>=0.1.11
     mlx-whisper
     f5-tts-mlx
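
The raised floor on mlx-lm lines up with the reworked generation helpers used in xinference/model/llm/mlx/core.py below (make_sampler plus make_logits_processors feeding stream_generate). As a hedged usage note, with a standard pip setup the environment markers mean the new pin only bites on Apple silicon:

```shell
# On macOS/arm64 (sys_platform=='darwin', platform_machine=='arm64') the
# mlx extra now resolves mlx-lm to 0.21.1 or newer; on other platforms
# the markers exclude the MLX stack entirely.
pip install "xinference[mlx]"
```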

xinference/model/llm/mlx/core.py (13 additions, 3 deletions)

@@ -103,10 +103,10 @@ def _sanitize_generate_config(
         # default config is adapted from
         # https://github.com/ml-explore/mlx-examples/blob/f212b770d8b5143e23102eda20400ae43340f844/llms/mlx_lm/utils.py#L129
         generate_config.setdefault("temperature", 0.0)
-        generate_config.setdefault("logit_bias", None)
         generate_config.setdefault("repetition_penalty", None)
         generate_config.setdefault("repetition_context_size", 20)
         generate_config.setdefault("top_p", 1.0)
+        generate_config.setdefault("logit_bias", None)
         return generate_config

     def _load_model(self, **kwargs):
@@ -199,14 +199,24 @@ def _get_prompt_cache(
         return prompt

     def _generate_stream_inner(self, **kwargs):
-        from mlx_lm.utils import make_sampler, stream_generate
+        from mlx_lm.utils import make_logits_processors, make_sampler, stream_generate

         sampler = make_sampler(
             temp=kwargs.pop("temperature"), top_p=kwargs.pop("top_p")
         )
         prompt_token_ids = kwargs.pop("prompt_token_ids")
+        logits_processors = make_logits_processors(
+            logit_bias=kwargs.pop("logits_bias", None),
+            repetition_penalty=kwargs.pop("repetition_penalty"),
+            repetition_context_size=kwargs.pop("repetition_context_size"),
+        )
         yield from stream_generate(
-            self._model, self._tokenizer, prompt_token_ids, sampler=sampler, **kwargs
+            self._model,
+            self._tokenizer,
+            prompt_token_ids,
+            sampler=sampler,
+            logits_processors=logits_processors,
+            **kwargs,
         )

     def _prepare_inputs(
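
For orientation, a minimal standalone sketch of the call path this diff adopts, assuming mlx-lm>=0.21.1, where stream_generate accepts sampler and logits_processors keyword arguments and yields response objects. The model id, prompt, and max_tokens value are illustrative assumptions, not taken from Xinference:

```python
# Sketch of the mlx-lm >= 0.21.1 generation path mirrored by the diff above:
# sampling knobs go through make_sampler, while logit_bias, repetition_penalty,
# and repetition_context_size go through make_logits_processors.
from mlx_lm import load
from mlx_lm.utils import make_logits_processors, make_sampler, stream_generate

# Illustrative model id; any MLX-format chat model should work.
model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")

# Values mirror the defaults set in _sanitize_generate_config above.
sampler = make_sampler(temp=0.0, top_p=1.0)
logits_processors = make_logits_processors(
    logit_bias=None,          # e.g. {token_id: bias} to nudge specific tokens
    repetition_penalty=None,  # e.g. 1.1 to discourage repeats
    repetition_context_size=20,
)

prompt_token_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Say hello."}], add_generation_prompt=True
)
for response in stream_generate(
    model,
    tokenizer,
    prompt_token_ids,
    sampler=sampler,
    logits_processors=logits_processors,
    max_tokens=64,
):
    print(response.text, end="", flush=True)
```

Splitting temperature/top_p into a sampler and the bias/repetition settings into logits processors follows the generation API reworked in mlx-lm 0.21, which is presumably what motivates the setup.cfg pin above.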
