[Feat] Support updating weights without restarting the server (#1157)
shanyu-sys authored Aug 20, 2024
1 parent 350a816 commit cd10654
Showing 7 changed files with 303 additions and 13 deletions.
5 changes: 5 additions & 0 deletions python/sglang/srt/managers/detokenizer_manager.py
@@ -28,6 +28,7 @@
BatchEmbeddingOut,
BatchStrOut,
BatchTokenIDOut,
UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs
@@ -84,6 +85,10 @@ async def handle_loop(self):
)
continue

if isinstance(recv_obj, UpdateWeightReqOutput):
self.send_to_tokenizer.send_pyobj(recv_obj)
continue

assert isinstance(recv_obj, BatchTokenIDOut)
bs = len(recv_obj.rids)

14 changes: 14 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -278,6 +278,20 @@ class FlushCacheReq:
pass


@dataclass
class UpdateWeightReqInput:
# The model path with the new weights
model_path: str
# The format to load the weights
load_format: Optional[str] = None


@dataclass
class UpdateWeightReqOutput:
success: bool
message: str


@dataclass
class AbortReq:
# The request id
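
These two dataclasses form the request/response pair that travels from the HTTP server through the tokenizer manager to the TP workers. A small usage sketch (the checkpoint path below is illustrative; leaving load_format as None makes the server fall back to its configured load format):

from dataclasses import asdict

from sglang.srt.managers.io_struct import UpdateWeightReqInput, UpdateWeightReqOutput

# Request: point the server at a new checkpoint (path is a placeholder).
req = UpdateWeightReqInput(model_path="/path/to/updated/checkpoint")
print(asdict(req))  # {'model_path': '/path/to/updated/checkpoint', 'load_format': None}

# Response: a plain success flag plus a human-readable message.
resp = UpdateWeightReqOutput(success=True, message="Succeeded to update model weights")
print(resp.success, resp.message)
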
45 changes: 42 additions & 3 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -46,6 +46,8 @@
GenerateReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling_params import SamplingParams
@@ -121,6 +123,10 @@ def __init__(
self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {}

# for update model weights
self.model_update_lock = asyncio.Lock()
self.model_update_result = None

async def get_pixel_values(self, image_data):
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
@@ -146,6 +152,9 @@ async def generate_request(
if self.to_create_loop:
self.create_handle_loop()

while self.model_update_lock.locked():
await asyncio.sleep(0)

obj.post_init()
is_single = obj.is_single

@@ -513,6 +522,30 @@ def flush_cache(self):
req = FlushCacheReq()
self.send_to_router.send_pyobj(req)

async def update_weights(self, obj: UpdateWeightReqInput, request):
if self.to_create_loop:
self.create_handle_loop()

# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format

if not self.model_update_lock.locked():
async with self.model_update_lock:
# wait for the previous generation requests to finish
while len(self.rid_to_state) > 0:
await asyncio.sleep(0)
self.send_to_router.send_pyobj(obj)
self.model_update_result = asyncio.Future()
result = await self.model_update_result
if result.success:
self.server_args.model_path = obj.model_path
self.server_args.load_format = obj.load_format
self.model_path = obj.model_path
return result.success, result.message
else:
return False, "Another update is in progress. Please try again later."

def abort_request(self, rid: str):
if rid not in self.rid_to_state:
return
@@ -541,12 +574,18 @@ def create_handle_loop(self):

async def handle_loop(self):
while True:
recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
await self.recv_from_detokenizer.recv_pyobj()
)
recv_obj: Union[
BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
] = await self.recv_from_detokenizer.recv_pyobj()

if isinstance(recv_obj, UpdateWeightReqOutput):
self.model_update_result.set_result(recv_obj)
continue

assert isinstance(
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
), f"Unexpected obj received: {type(recv_obj)}"

for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
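
The update path above coordinates with in-flight generation through a shared asyncio.Lock and a Future: new generation requests spin while the lock is held, the updater drains requests that were already admitted, and the result arrives via a Future that handle_loop resolves. A standalone sketch of that handshake (names and timings are illustrative, not the actual TokenizerManager internals):

import asyncio


async def main():
    lock = asyncio.Lock()          # mirrors model_update_lock
    inflight = {"n": 0}            # mirrors len(rid_to_state)
    result_future = None           # mirrors model_update_result

    async def generate():
        # New requests are held back while an update owns the lock.
        while lock.locked():
            await asyncio.sleep(0)
        inflight["n"] += 1
        try:
            await asyncio.sleep(0.01)          # stand-in for real decoding work
        finally:
            inflight["n"] -= 1

    async def worker_reload():
        await asyncio.sleep(0.02)              # stand-in for the worker reloading weights
        result_future.set_result((True, "weights updated"))

    async def update_weights():
        nonlocal result_future
        if lock.locked():
            return False, "Another update is in progress. Please try again later."
        async with lock:
            while inflight["n"] > 0:           # drain requests admitted before the lock
                await asyncio.sleep(0)
            result_future = asyncio.get_running_loop().create_future()
            asyncio.create_task(worker_reload())   # stands in for send_to_router.send_pyobj
            return await result_future

    gens = [asyncio.create_task(generate()) for _ in range(4)]
    print(await update_weights())              # (True, 'weights updated')
    await asyncio.gather(*gens)


asyncio.run(main())
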
17 changes: 17 additions & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -39,6 +39,8 @@
FlushCacheReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightReqInput,
UpdateWeightReqOutput,
)
from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
from sglang.srt.managers.schedule_batch import (
@@ -214,6 +216,9 @@ def exposed_step(self, recv_reqs: List):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
self.abort_request(recv_req)
elif isinstance(recv_req, UpdateWeightReqInput):
success, message = self.update_weights(recv_req)
self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
else:
raise ValueError(f"Invalid request: {recv_req}")

@@ -773,12 +778,15 @@ def flush_cache(self):
self.token_to_kv_pool.clear()
torch.cuda.empty_cache()
logger.info("Cache flushed successfully!")
if_success = True
else:
logging.warning(
f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.waiting_queue)}, "
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
)
if_success = False
return if_success

def abort_request(self, recv_req):
# Delete requests in the waiting queue
@@ -798,6 +806,15 @@ def abort_request(self, recv_req):
req.finished_reason = FINISH_ABORT()
break

def update_weights(self, recv_req):
success, message = self.model_runner.update_weights(
recv_req.model_path, recv_req.load_format
)
if success:
flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"
return success, message


def run_tp_server(
gpu_id: int,
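
Worth noting: flush_cache now reports whether it actually cleared the pools, and update_weights asserts on it. The assert can only hold because the tokenizer manager drains every pending request before sending UpdateWeightReqInput, so the worker sees an empty waiting queue and no running batch. A condensed sketch of that invariant (stand-in objects, not the real ModelTpServer types):

class WorkerSketch:
    """Illustrative stand-in for the TP worker's update path."""

    def __init__(self):
        self.waiting_queue = []     # already drained by the tokenizer manager
        self.running_batch = None
        self.cache_cleared = False

    def flush_cache(self):
        if len(self.waiting_queue) == 0 and self.running_batch is None:
            self.cache_cleared = True   # stands in for clearing req/token pools
            return True
        return False                    # refuse to flush while requests are in flight

    def update_weights(self, load_ok=True):
        # Reload weights first (simulated), then invalidate the stale KV cache.
        if load_ok:
            assert self.flush_cache(), "Cache flush failed after updating weights"
        return load_ok, "ok" if load_ok else "load failed"


print(WorkerSketch().update_weights())  # (True, 'ok')
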
106 changes: 97 additions & 9 deletions python/sglang/srt/model_executor/model_runner.py
@@ -15,6 +15,7 @@

"""ModelRunner runs the forward passes of the models."""

import gc
import importlib
import importlib.resources
import logging
@@ -157,9 +158,9 @@ def load_model(self):
self.server_args.dtype = "float16"

monkey_patch_vllm_dummy_weight_loader()
device_config = DeviceConfig()
load_config = LoadConfig(load_format=self.server_args.load_format)
vllm_model_config = VllmModelConfig(
self.device_config = DeviceConfig()
self.load_config = LoadConfig(load_format=self.server_args.load_format)
self.vllm_model_config = VllmModelConfig(
model=self.server_args.model_path,
quantization=self.server_args.quantization,
tokenizer=None,
@@ -173,17 +174,19 @@
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
self.model_config.hf_config.num_key_value_heads = 8
vllm_model_config.hf_config.num_key_value_heads = 8
self.vllm_model_config.hf_config.num_key_value_heads = 8
monkey_patch_vllm_qvk_linear_loader()

self.dtype = vllm_model_config.dtype
self.dtype = self.vllm_model_config.dtype
if self.model_config.model_overide_args is not None:
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
self.vllm_model_config.hf_config.update(
self.model_config.model_overide_args
)

self.model = get_model(
model_config=vllm_model_config,
device_config=device_config,
load_config=load_config,
model_config=self.vllm_model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=None,
multimodal_config=None,
parallel_config=None,
@@ -206,6 +209,91 @@ def load_model(self):
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)

def update_weights(self, model_path, load_format):
from vllm.model_executor.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
get_model_loader,
)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype

logger.info(
f"[gpu={self.gpu_id}] Update weights begin. "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)

target_device = torch.device(self.device_config.device)

try:
vllm_model_config = VllmModelConfig(
model=model_path,
quantization=self.server_args.quantization,
tokenizer=None,
tokenizer_mode=None,
trust_remote_code=self.server_args.trust_remote_code,
dtype=self.server_args.dtype,
seed=42,
skip_tokenizer_init=True,
)
except Exception as e:
logger.error(f"Failed to load model config: {e}")
return False, "Failed to update model weights"

load_config = LoadConfig(load_format=load_format)

# Only support vllm DefaultModelLoader for now
loader = get_model_loader(load_config)
if not isinstance(loader, DefaultModelLoader):
logger.error("Failed to get weights iterator: Unsupported loader")
return False, "Failed to update model weights"

def get_weight_iter(config):
iter = loader._get_weights_iterator(
config.model,
config.revision,
fall_back_to_pt=getattr(
self.model, "fall_back_to_pt_during_load", True
),
)
return iter

def model_load_weights(model, iter):
model.load_weights(iter)
for _, module in self.model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
return model

with set_default_torch_dtype(vllm_model_config.dtype):
try:
iter = get_weight_iter(vllm_model_config)
except Exception as e:
message = f"Failed to get weights iterator: {e}"
logger.error(message)
return False, message
try:
model = model_load_weights(self.model, iter)
except Exception as e:
message = f"Failed to update weights: {e}. \n Rolling back to original weights"
logger.error(message)
del iter
gc.collect()
iter = get_weight_iter(self.vllm_model_config)
self.model = model_load_weights(self.model, iter)
return False, message

self.model = model
self.server_args.model_path = model_path
self.server_args.load_format = load_format
self.vllm_model_config = vllm_model_config
self.load_config = load_config
self.model_config.path = model_path

logger.info(f"[gpu={self.gpu_id}] Update weights end.")
return True, "Succeeded to update model weights"

def profile_max_num_token(self, total_gpu_memory):
available_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
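
The key property of update_weights here is that loading is in-place and guarded: if streaming the new checkpoint fails partway, the original weights iterator is replayed so the model is not left half-updated. A schematic of that try/rollback pattern in plain Python (the loader callables and fake model are stand-ins, not the vLLM APIs used above):

def load_with_rollback(model, get_new_iter, get_old_iter):
    """Stream new weights into `model`; on failure, restore the old ones."""
    try:
        new_iter = get_new_iter()
    except Exception as e:
        return False, f"Failed to get weights iterator: {e}"
    try:
        model.load_weights(new_iter)
    except Exception as e:
        # Loading mutates the model in place, so a partial failure leaves mixed
        # weights behind; replay the original checkpoint to roll back.
        model.load_weights(get_old_iter())
        return False, f"Failed to update weights: {e}. Rolled back to original weights"
    return True, "Succeeded to update model weights"


class FakeModel:
    def __init__(self):
        self.weights = {}

    def load_weights(self, weight_iter):
        for name, value in weight_iter:
            self.weights[name] = value


m = FakeModel()
print(load_with_rollback(m, lambda: iter([("w", 1.0)]), lambda: iter([("w", 0.0)])))
# (True, 'Succeeded to update model weights')
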
23 changes: 22 additions & 1 deletion python/sglang/srt/server.py
@@ -51,7 +51,11 @@
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
UpdateWeightReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
load_chat_template_for_openai_api,
@@ -136,6 +140,23 @@ async def flush_cache():
)


@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):

success, message = await tokenizer_manager.update_weights(obj, request)
content = {"message": message, "success": str(success)}
if success:
return JSONResponse(
content,
status_code=HTTPStatus.OK,
)
else:
return JSONResponse(
content,
status_code=HTTPStatus.BAD_REQUEST,
)


async def generate_request(obj: GenerateReqInput, request: Request):
"""Handle a generate request."""
if obj.stream:
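
Once the server is running, the new endpoint accepts a JSON body mirroring UpdateWeightReqInput. A minimal client sketch (host, port, and checkpoint path are placeholders for your deployment; 30000 is sglang's usual default port):

import requests

resp = requests.post(
    "http://localhost:30000/update_weights",             # placeholder address
    json={"model_path": "/path/to/updated/checkpoint"},  # load_format is optional
)
print(resp.status_code)  # 200 if the update succeeded, 400 if it was rejected
print(resp.json())       # {"message": "...", "success": "True"} on success

Generation requests that arrive while an update is in flight simply wait for the lock in TokenizerManager to clear, so the server keeps serving without a restart.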