diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py
index d2369dc2e89..e827d04d36b 100644
--- a/verl/workers/rollout/vllm_rollout/utils.py
+++ b/verl/workers/rollout/vllm_rollout/utils.py
@@ -218,6 +218,9 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
         if use_standard_weight_load:
            patch_vllm_moe_model_weight_loader(self.model_runner.model)
 
+        # save vllm model to safetensors before weight update
+        self._save_vllm_model_to_safetensors(suffix="before")
+
         # receive bucket and update weights
         while True:
             metadata = socket.recv_pyobj()
@@ -249,6 +252,9 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False
         model_config = self.model_runner.vllm_config.model_config
         process_weights_after_loading(model, model_config, self.device)
 
+        # save vllm model to safetensors after weight update
+        self._save_vllm_model_to_safetensors(suffix="after")
+
         # clean up
         socket.close()
         del buffer
@@ -284,6 +290,25 @@ def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config:
             logger.info("Loading standard weights (non-FP8, async)")
             self.model_runner.model.load_weights(weights)
 
+    def _save_vllm_model_to_safetensors(self, suffix: str):
+        """Save vLLM model to safetensors via save_model.
+
+        Dir from env VERL_VLLM_WEIGHT_SAVE_DIR (default None for not saving).
+
+        Args:
+            suffix: The suffix of the saved model.
+        """
+        save_dir = os.environ.get("VERL_VLLM_WEIGHT_SAVE_DIR", None)
+        if save_dir is None:
+            return
+        os.makedirs(save_dir, exist_ok=True)
+        model = self.model_runner.model
+        rank = getattr(self, "local_rank", 0)
+        device_uuid = get_device_uuid(self.device.index)
+        path = os.path.join(save_dir, f"vllm_weights_{suffix}_rank{rank}_device{device_uuid}.safetensors")
+        from safetensors.torch import save_model
+
+        save_model(model, path)
+        logger.info(f"vLLM model saved to {path}")
+
     def _get_zmq_handle(self) -> str:
         """Get ZMQ handle for communication."""
         if not hasattr(self, "device_uuid") or not self.device_uuid:
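
The before/after dumps written by _save_vllm_model_to_safetensors can be compared offline to confirm that the IPC sync actually changed the engine's parameters. Below is a minimal sketch of such a comparison; the concrete file names are only illustrative (the rank and device-UUID parts depend on the run), and the directory is whatever VERL_VLLM_WEIGHT_SAVE_DIR was set to during rollout.

import os

import torch
from safetensors.torch import load_file

save_dir = os.environ["VERL_VLLM_WEIGHT_SAVE_DIR"]

# File names follow the pattern used in the diff:
#   vllm_weights_{suffix}_rank{rank}_device{device_uuid}.safetensors
# The rank/device values below are hypothetical placeholders.
before = load_file(os.path.join(save_dir, "vllm_weights_before_rank0_deviceGPU-0000.safetensors"))
after = load_file(os.path.join(save_dir, "vllm_weights_after_rank0_deviceGPU-0000.safetensors"))

# Report which parameters changed, and by how much, during the weight update.
for name in sorted(before.keys() & after.keys()):
    if not torch.equal(before[name], after[name]):
        delta = (after[name].float() - before[name].float()).abs().max().item()
        print(f"{name}: max |delta| = {delta:.3e}")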