From 52c294c6819c702a0ed308f0798f65933c295e0b Mon Sep 17 00:00:00 2001 From: pathfinder-fp Date: Tue, 23 Sep 2025 09:48:50 +0800 Subject: [PATCH] add default_sampling_params --- python/sgl_jax/srt/entrypoints/engine.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/sgl_jax/srt/entrypoints/engine.py b/python/sgl_jax/srt/entrypoints/engine.py index ab9e6e20..c2dca1c0 100644 --- a/python/sgl_jax/srt/entrypoints/engine.py +++ b/python/sgl_jax/srt/entrypoints/engine.py @@ -12,7 +12,7 @@ import os import signal import threading -from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union +from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union, Any import zmq import zmq.asyncio @@ -43,6 +43,7 @@ set_ulimit, ) from sgl_jax.version import __version__ +from sgl_jax.srt.sampling import SamplingParams logger = logging.getLogger(__name__) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -99,6 +100,8 @@ def __init__(self, **kwargs): context, zmq.DEALER, self.port_args.rpc_ipc_name, True ) + self.default_sampling_params: Union[dict[str, Any], None] = None + def generate( self, prompt: Optional[Union[List[str], str]] = None, @@ -343,6 +346,15 @@ async def async_score( request=None, ) + def get_default_sampling_params(self) -> SamplingParams: + if self.default_sampling_params is None: + self.default_sampling_params = ( + self.llm_engine.model_config.get_diff_sampling_param()) + if self.default_sampling_params: + return SamplingParams.from_optional(**self.default_sampling_params) + return SamplingParams() + + def _set_envs_and_config(): # Set ulimit