diff --git a/examples/vllm_inference.py b/examples/vllm_inference.py
index f9d329cb5..d6e4b7859 100644
--- a/examples/vllm_inference.py
+++ b/examples/vllm_inference.py
@@ -53,6 +53,7 @@ def main():
         dataset,
         release_gpu=False,
         enable_decode_inference_result=pipeline_args.enable_decode_inference_result,
+        enable_distributed_vllm_inference=pipeline_args.enable_distributed_vllm_inference,
     )


diff --git a/scripts/run_vllm_inference.sh b/scripts/run_vllm_inference.sh
index 681d2d5ec..8e9598dc3 100644
--- a/scripts/run_vllm_inference.sh
+++ b/scripts/run_vllm_inference.sh
@@ -74,4 +74,5 @@ python examples/vllm_inference.py \
     --enable_decode_inference_result False \
     --vllm_gpu_memory_utilization 0.95 \
     --vllm_tensor_parallel_size 2 \
+    --enable_distributed_vllm_inference False \
     2>&1 | tee ${log_dir}/vllm_inference.log
\ No newline at end of file
diff --git a/src/lmflow/pipeline/rm_inferencer.py b/src/lmflow/pipeline/rm_inferencer.py
index ab1058d63..fa4c9bbbe 100644
--- a/src/lmflow/pipeline/rm_inferencer.py
+++ b/src/lmflow/pipeline/rm_inferencer.py
@@ -12,9 +12,12 @@
 import json
 import time
 import logging
-from typing import Dict, List, Union, Tuple
+from typing import Dict, List, Union, Tuple, Any
 
 from accelerate import Accelerator
+import ray
+import ray.data
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 import torch
 from tqdm import tqdm
 from transformers import AutoConfig
@@ -32,7 +35,8 @@
 from lmflow.pipeline.base_pipeline import BasePipeline
 from lmflow.utils.data_utils import (
     set_random_seed,
-    batchlize
+    batchlize,
+    RewardModelInferenceResultWithInput,
 )
 from lmflow.datasets.dataset import KEY_SCORE
 
@@ -61,6 +65,7 @@ def __init__(
         model_args: ModelArguments,
         data_args: DatasetArguments,
         inferencer_args: InferencerArguments,
+        **kwargs,
     ):
         self.data_args = data_args
         self.inferencer_args = inferencer_args
@@ -79,8 +84,7 @@ def __init__(
         )
 
         if inferencer_args.use_accelerator:
-            self.accelerator = Accelerator()
-            self.accelerator.wait_for_everyone()
+            self.accelerator: Accelerator = kwargs.get('accelerator', Accelerator())
 
 
     def inference(
@@ -89,6 +93,8 @@ def inference(
         dataset: Dataset,
         transform_dataset_in_place: bool=True,
         use_vllm: bool = False,
+        enable_distributed_inference: bool = False,
+        **kwargs,
     ) -> Dataset:
         if use_vllm:
             logger.warning("VLLM doesn't support reward model inference, using normal inference instead.")
@@ -98,49 +104,73 @@
         if not transform_dataset_in_place:
             dataset = copy.deepcopy(dataset)
 
-        output_dict = {"type": "", "instances": []}
-        if dataset.get_type() == "text_to_textlist":
-            output_dict["type"] = "text_to_scored_textlist"
-            for idx, instance in enumerate(dataset.get_backend_dataset()):
-                if len(instance["output"]) < 2:
-                    logger.warning(f"Instance {idx} has less than 2 outputs, skipping.")
-                output_dict["instances"].append(
-                    {
-                        "input": instance["input"],
-                        "output": [{"text": text} for text in instance["output"]],
-                    }
-                )
-        else:
-            raise NotImplementedError(f"Dataset type {dataset.get_type()} is not supported for reward model inference.")
-
+        model_input = model.prepare_inputs_for_inference(
+            dataset=dataset,
+            apply_chat_template=True,
+            enable_distributed_inference=enable_distributed_inference,
+            use_vllm=use_vllm
+        )
+
         if use_vllm:
-            scores = self.__vllm_inference(model, dataset)
+            inference_result = self.__vllm_inference(
+                model=model,
+                model_input=model_input,
+                enable_distributed_inference=enable_distributed_inference,
+            )
         else:
-            scores = self.__inference(model, dataset)
-
-        for i, instance_scores in enumerate(scores):
-            for j, score in enumerate(instance_scores):
-                output_dict["instances"][i]["output"][j][KEY_SCORE] = score
+            inference_result = self._inference(
+                model=model,
+                model_input=model_input,
+                enable_distributed_inference=enable_distributed_inference,
+                **kwargs,
+            )
 
-        output_dataset_args = copy.deepcopy(self.data_args)
-        output_dataset_args.dataset_path = None
-        output_dataset_args.dataset_name = f"{output_dataset_args.dataset_name}_scored"
-        output_dataset = Dataset(output_dataset_args)
-        output_dataset = output_dataset.from_dict(output_dict)
+        if enable_distributed_inference:
+            output_dataset = model.postprocess_distributed_inference_outputs(
+                dataset=dataset,
+                inference_result=inference_result,
+            )
+        else:
+            output_dataset = model.postprocess_inference_outputs(
+                dataset=dataset,
+                scores=inference_result
+            )
 
         return output_dataset
+
+
+    def _inference(
+        self,
+        model: HFTextRegressionModel,
+        model_input: Union[Dataset, ray.data.Dataset],
+        enable_distributed_inference: bool = False,
+        **kwargs,
+    ):
+        if enable_distributed_inference:
+            inference_res = self.__distributed_inference(
+                model=model,
+                model_input=model_input,
+                num_instances=kwargs.get("distributed_inference_num_instances", 1),
+                batch_size=kwargs.get("inference_batch_size", 1),
+            )
+        else:
+            inference_res = self.__inference(
+                model=model,
+                model_input=model_input,
+            )
+
+        return inference_res
 
 
     def __inference(
         self,
         model: HFTextRegressionModel,
-        dataset: Dataset,
+        model_input: Dataset,
     ) -> Union[List[float], List[List[float]]]:
-        tokenized_dataset = model.tokenize(dataset)
-        if dataset.get_type() in ["text_to_textlist"]:
-            model_input_ids, num_outputs = self.flatten_list(tokenized_dataset.get_backend_dataset()["input_ids"])
+        if model_input.get_type() in ["text_to_textlist"]:
+            model_input_ids, num_outputs = self.flatten_list(model_input.get_backend_dataset()["input_ids"])
         else:
-            model_input_ids = tokenized_dataset.get_backend_dataset()["input_ids"]
+            model_input_ids = model_input.get_backend_dataset()["input_ids"]
 
         dataloader = batchlize(
             examples=model_input_ids,
@@ -157,32 +187,132 @@ def __inference(
             unit="batch"
         ):
             # len(batch) = batch_size, and batch element is dataset sample
-            model_input = torch.LongTensor(batched_input_ids).to("cpu" if model.device == "cpu" else "cuda")
+            model_input_tensor = torch.LongTensor(batched_input_ids).to("cpu" if model.device == "cpu" else "cuda")
             if self.inferencer_args.use_accelerator:
                 with self.accelerator.autocast():
                     batch_output = model.inference(
-                        inputs=model_input,
+                        inputs=model_input_tensor,
                         use_vllm=False,
                     )
             else:
                 batch_output = model.inference(
-                    inputs=model_input,
+                    inputs=model_input_tensor,
                     use_vllm=False,
                 )
 
             batch_output = self.__post_process_model_output(batch_output)
             final_output.extend(batch_output)
 
-        if dataset.get_type() in ["text_to_textlist"]:
+        if model_input.get_type() in ["text_to_textlist"]:
             final_output = self.compress_list(final_output, num_outputs)
 
         return final_output
 
 
+    def __distributed_inference(
+        self,
+        model: HFTextRegressionModel,
+        model_input: ray.data.Dataset,
+        num_instances: int,
+        batch_size: int,
+    ) -> List[RewardModelInferenceResultWithInput]:
+        def scheduling_strategy_fn():
+            # One bundle per tensor parallel worker
+            pg = ray.util.placement_group(
+                [{
+                    "GPU": 1,
+                    "CPU": 1
+                }] * self.inferencer_args.tensor_parallel_size,
+                strategy="STRICT_PACK",
+            )
+            return dict(
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    pg, placement_group_capture_child_tasks=True
+                )
+            )
+
+        resources_kwarg: Dict[str, Any] = {}
+        if self.inferencer_args.tensor_parallel_size == 1:
+            # For tensor_parallel_size == 1, we simply set num_gpus=1.
+            resources_kwarg["num_gpus"] = 1
+        else:
+            # Otherwise, we have to set num_gpus=0 and provide
+            # a function that will create a placement group for
+            # each instance.
+            resources_kwarg["num_gpus"] = 0
+            resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn
+
+        ## predictor
+        class DistributedPredictor:
+            def __init__(
+                self,
+                model_args: ModelArguments,
+            ):
+                self.model = HFTextRegressionModel(
+                    model_args=model_args,
+                    tune_strategy='none',
+                    use_accelerator=True
+                )
+                self.model.activate_model_for_inference(use_vllm=False)
+
+            def __call__(self, batch: Dict[str, np.ndarray]):
+                """batch: Dict[str, np.ndarray]
+                Example (batch size=2):
+                {'input': array(['...', '...'], dtype=object),
+                 'output': array([array(['...', '...'], dtype=object), array(['...', '...'], dtype=object)], dtype=object),
+                 'input_ids': array([[[128000, 128006, 882, ..., 128256, 128256, 128256],
+                                      [128000, 128006, 882, ..., 128256, 128256, 128256]],
+                                     [[128000, 128006, 882, ..., 128256, 128256, 128256],
+                                      [128000, 128006, 882, ..., 128256, 128256, 128256]]])}
+                """
+                # The batch is managed by ray, and the actual batch size may be smaller than
+                # inference_batch_size in config, since there may be some remainders.
+                # For example, with 10 examples, 2 inference instances, and inference_batch_size=4,
+                # instance 0 may receive a final batch of only 2 examples, so the
+                # actual batch size changes.
+                actual_batch_size = len(batch['input'])
+                batched_inference_res = self.model.inference(
+                    inputs=torch.LongTensor(batch['input_ids']).flatten(start_dim=0, end_dim=1).to("cuda"),
+                ).logits
+                batched_inference_res = batched_inference_res.to("cpu").reshape(actual_batch_size, -1, 1).squeeze(dim=-1).tolist() # [bs, num_output_sequences]
+                batched_final_res = {
+                    "input": batch['input'].tolist(),
+                    "output": [
+                        [
+                            {"score": batched_inference_res[j][i], "text": batch["output"][j][i]}
+                            for i in range(len(batch['output'][j]))
+                        ]
+                        for j in range(actual_batch_size)
+                    ],
+                } # do this since we're writing to a pandas dataframe
+                return batched_final_res
+
+        # inference
+        model_input_mapping = model_input.map_batches(
+            DistributedPredictor,
+            concurrency=num_instances, # Set the concurrency to the number of LLM instances.
+            batch_size=batch_size,
+            fn_constructor_kwargs={
+                "model_args": model.model_args,
+            },
+            **resources_kwarg,
+        )
+
+        df_model_output = model_input_mapping.to_pandas() # the actual forwards are executed here
+        logger.info(f"Distributed reward model inference result preview:\n{df_model_output.head(10)}")
+
+        model_output = [
+            {"input": row["input"], "output": row["output"]} for _, row in df_model_output.iterrows()
+        ]
+
+        return model_output
+
+
     def __vllm_inference(
         self,
         model: HFTextRegressionModel,
-        dataset: Dataset,
+        model_input: List[str],
+        enable_distributed_inference: bool = False,
     ) -> List[float]:
         raise NotImplementedError("VLLM inference for reward model is not implemented yet.")
 
diff --git a/src/lmflow/pipeline/utils/memory_safe_vllm_inference.py b/src/lmflow/pipeline/utils/memory_safe_vllm_inference.py
index 86f765acb..d7859d9a1 100644
--- a/src/lmflow/pipeline/utils/memory_safe_vllm_inference.py
+++ b/src/lmflow/pipeline/utils/memory_safe_vllm_inference.py
@@ -55,6 +55,9 @@ def main():
         dataset,
         release_gpu=False,
         enable_decode_inference_result=pipeline_args.enable_decode_inference_result,
+        enable_distributed_inference=pipeline_args.enable_distributed_inference,
+        distributed_inference_num_instances=pipeline_args.distributed_inference_num_instances,
+        inference_batch_size=pipeline_args.vllm_inference_batch_size,
     )
 
     # use this as a flag, stdout will be captured by the pipeline
diff --git a/src/lmflow/pipeline/vllm_inferencer.py b/src/lmflow/pipeline/vllm_inferencer.py
index f6c1f71a4..a68cb99aa 100644
--- a/src/lmflow/pipeline/vllm_inferencer.py
+++ b/src/lmflow/pipeline/vllm_inferencer.py
@@ -1,19 +1,23 @@
 #!/usr/bin/env python
 # coding=utf-8
 # Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
-import os
-import sys
-import signal
+import copy
+from functools import partial
+import importlib.resources as pkg_resources
 import json
-from pathlib import Path
 import logging
+import os
+os.environ['VLLM_WORKER_MULTIPROC_METHOD']='spawn'
 import subprocess
-import importlib.resources as pkg_resources
-from typing import List, Union, Optional
-import time
+import sys
+from typing import List, Union, Optional, Dict, Any
 
-from vllm import SamplingParams
+import numpy as np
+import ray
+import ray.data
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 from transformers import AutoTokenizer
+from vllm import SamplingParams, LLM
 
 from lmflow.datasets import Dataset
 from lmflow.pipeline.base_pipeline import BasePipeline
@@ -24,7 +28,8 @@
     DatasetArguments,
 )
 from lmflow.utils.common import make_shell_args_from_dataclass
-from lmflow.utils.constants import RETURN_CODE_ERROR_BUFFER
+from lmflow.utils.constants import RETURN_CODE_ERROR_BUFFER, MEMORY_SAFE_VLLM_INFERENCE_ENV_VAR_TO_REMOVE
+from lmflow.utils.data_utils import VLLMInferenceResultWithInput
 
 
 logger = logging.getLogger(__name__)
@@ -78,8 +83,8 @@ def parse_to_sampling_params(
             top_k=inference_args.top_k,
             stop_token_ids=[self.eos_token_id] + inference_args.additional_stop_token_ids
         )
-        
-        
+
+
     def inference(
@@ -87,7 +92,9 @@ def inference(
         enable_decode_inference_result: bool = True,
         release_gpu: bool = False,
         inference_args: Optional[InferencerArguments] = None,
-    ) -> Union[List[List[str]], List[List[List[int]]]]:
+        enable_distributed_inference: bool = False,
+        **kwargs,
+    ) -> List[VLLMInferenceResultWithInput]:
         """Perform inference using the provided model and dataset.
 
         Will save inference results if `save_results` is set to True in `inferencer_args`.
@@ -108,13 +115,14 @@
 
         Returns
         -------
-        Union[List[List[str]], List[List[List[int]]]]
-            When `enable_decode_inference_result = True`, return a list of list of strings. Inner list
-            contains inference_args.num_output_sequences samples for a single prompt
-            (i.e., `len(res[i]) = inference_args.num_output_sequences`). Outer list
-            contains the results for all prompts (i.e., `len(res) = len(dataset)`).
+        List[VLLMInferenceResultWithInput]
+            Return a list of VLLMInferenceResultWithInput, where each
+            element contains the input prompt and the corresponding output.
+
+            When `enable_decode_inference_result = True`, the output would be a list of strings,
+            containing sampling_params.n samples for the corresponding prompt.
 
-            When `enable_decode_inference_result = False`, return a list of list of list of ints
+            When `enable_decode_inference_result = False`, the output would be a list of lists of ints
             (token ids, no decoding after generation).
         """
         if inference_args:
@@ -128,26 +136,151 @@
         sampling_params.detokenize = enable_decode_inference_result
 
         model_input = model.prepare_inputs_for_inference(
-            dataset=dataset, 
+            dataset=dataset,
             apply_chat_template=self.inferencer_args.apply_chat_template,
-            use_vllm=self.inferencer_args.use_vllm
+            use_vllm=self.inferencer_args.use_vllm,
+            enable_distributed_inference=enable_distributed_inference,
         )
 
+        if enable_distributed_inference:
+            outputs = self._distributed_inference(
+                model=model,
+                model_input=model_input,
+                sampling_params=sampling_params,
+                num_instances=kwargs.get("distributed_inference_num_instances"),
+                batch_size=kwargs.get("inference_batch_size", 4),
+                release_gpu=release_gpu,
+            )
+        else:
+            outputs = self._inference(
+                model=model,
+                model_input=model_input,
+                sampling_params=sampling_params,
+                release_gpu=release_gpu,
+            )
+
+        if self.inferencer_args.save_results:
+            self.save_inference_results(outputs, self.inferencer_args.results_path)
+
+        return outputs
+
+
+    def _inference(
+        self,
+        model: HFDecoderModel,
+        model_input: List[str],
+        sampling_params: SamplingParams,
+        release_gpu: bool = False,
+    ) -> List[VLLMInferenceResultWithInput]:
         outputs = model.inference(
             inputs=model_input,
             sampling_params=sampling_params,
             release_gpu=release_gpu,
-            use_vllm=self.inferencer_args.use_vllm,
+            use_vllm=True,
             vllm_gpu_memory_utilization=self.inferencer_args.vllm_gpu_memory_utilization,
             vllm_tensor_parallel_size=self.inferencer_args.vllm_tensor_parallel_size,
         )
-        
-        if self.inferencer_args.save_results:
-            self.save_inference_results(outputs, self.inferencer_args.results_path)
-        
+
         return outputs
 
+
+    def _distributed_inference(
+        self,
+        model: HFDecoderModel,
+        model_input: ray.data.Dataset,
+        sampling_params: SamplingParams,
+        num_instances: int,
+        batch_size: int = 4,
+        release_gpu: bool = False,
+    ) -> List[VLLMInferenceResultWithInput]:
+        # prepare distributed inference resources
+        # from https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_distributed.py
+        ## strategy
+        def scheduling_strategy_fn():
+            # One bundle per tensor parallel worker
+            pg = ray.util.placement_group(
+                [{
+                    "GPU": 1,
+                    "CPU": 1
+                }] * self.inferencer_args.vllm_tensor_parallel_size,
+                strategy="STRICT_PACK",
+            )
+            return dict(
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    pg, placement_group_capture_child_tasks=True
+                )
+            )
+
+        resources_kwarg: Dict[str, Any] = {}
+        if self.inferencer_args.vllm_tensor_parallel_size == 1:
+            # For tensor_parallel_size == 1, we simply set num_gpus=1.
+            resources_kwarg["num_gpus"] = 1
+        else:
+            # Otherwise, we have to set num_gpus=0 and provide
+            # a function that will create a placement group for
+            # each instance.
+            resources_kwarg["num_gpus"] = 0
+            resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn
+
+        ## predictor
+        class DistributedPredictor:
+            def __init__(
+                self,
+                model: HFDecoderModel,
+                sampling_params: SamplingParams,
+                vllm_gpu_memory_utilization: float,
+                vllm_tensor_parallel_size: int,
+                release_gpu: bool=False,
+            ):
+                self.model = copy.deepcopy(model)
+                self.model.activate_model_for_inference(
+                    use_vllm=True,
+                    vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
+                    vllm_tensor_parallel_size=vllm_tensor_parallel_size,
+                )
+                self.sampling_params = sampling_params
+                self.release_gpu = release_gpu
+
+            def __call__(self, batch: Dict[str, np.ndarray]):
+                """batch: Dict[str, np.ndarray], {"item": array(['...', '...', '...', ...])}
+                """
+                batched_inference_res = self.model.inference(
+                    inputs=batch['item'],
+                    sampling_params=self.sampling_params,
+                    release_gpu=self.release_gpu,
+                    use_vllm=True,
+                ) # this is the postprocessed output, see model.__vllm_inference
+                batched_final_res = {
+                    "input": [sample['input'] for sample in batched_inference_res],
+                    "output": [sample['output'] for sample in batched_inference_res]
+                } # do this since we're writing to a pandas dataframe
+                return batched_final_res
+
+        # inference
+        model_input_mapping = model_input.map_batches(
+            DistributedPredictor,
+            concurrency=num_instances, # Set the concurrency to the number of LLM instances.
+            batch_size=batch_size,
+            fn_constructor_kwargs={
+                "model": model,
+                "sampling_params": sampling_params,
+                "vllm_gpu_memory_utilization": self.inferencer_args.vllm_gpu_memory_utilization,
+                "vllm_tensor_parallel_size": self.inferencer_args.vllm_tensor_parallel_size,
+                "release_gpu": release_gpu,
+            },
+            **resources_kwarg,
+        )
+
+        df_model_output = model_input_mapping.to_pandas() # the actual forwards are executed here
+        logger.info(f"Distributed vllm inference result preview:\n{df_model_output.head(10)}")
+
+        model_output = [
+            {"input": row["input"], "output": row["output"]} for _, row in df_model_output.iterrows()
+        ]
+
+        return model_output
+
+
     def save_inference_results(
         self,
         outputs: Union[List[List[str]], List[List[List[int]]]],
@@ -181,7 +314,7 @@ def __init__(
 
         self.inferencer_file_path = pkg_resources.files("lmflow.pipeline.utils") / "memory_safe_vllm_inference.py"
 
 
-    def inference(self):
+    def inference(self) -> List[VLLMInferenceResultWithInput]:
         inferencer_args = make_shell_args_from_dataclass(
             dataclass_objects=[
                 self.model_args,
@@ -191,13 +324,17 @@ def inference(self):
             format="shell",
         )
         cmd = "python " + str(self.inferencer_file_path) + " " + inferencer_args
+        current_env = os.environ.copy()
+        for var in MEMORY_SAFE_VLLM_INFERENCE_ENV_VAR_TO_REMOVE:
+            current_env.pop(var, None)
 
         cli_res = subprocess.run(
             args=cmd,
             stdout=sys.stdout,
             stderr=sys.stdout,
             shell=True,
-            preexec_fn=os.setsid
+            preexec_fn=os.setsid,
+            env=current_env,
         )
         logger.info(f"MemorySafeVLLMInference subprocess run finished, info at finish: {cli_res}")
 
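Note (not part of the patch): both `__distributed_inference` methods above use the same Ray Data pattern — one actor per model replica via `map_batches`, with GPUs either requested directly (`num_gpus=1` when tensor_parallel_size == 1) or supplied through a STRICT_PACK placement group. The sketch below isolates that pattern with a dummy predictor so the scheduling logic can be read on its own; `ToyPredictor`, the constants `TENSOR_PARALLEL_SIZE`/`NUM_INSTANCES`, and the fake "score = prompt length" output are illustrative assumptions, not LMFlow APIs, and running it as-is assumes a Ray cluster with GPUs (drop `num_gpus` for a CPU-only dry run).

from typing import Any, Dict

import numpy as np
import ray
import ray.data
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

TENSOR_PARALLEL_SIZE = 1   # assumption: one GPU per replica; >1 takes the placement-group path
NUM_INSTANCES = 2          # number of model replicas (Ray actors)


def scheduling_strategy_fn() -> Dict[str, Any]:
    # One bundle per tensor-parallel worker, packed onto a single node.
    pg = ray.util.placement_group(
        [{"GPU": 1, "CPU": 1}] * TENSOR_PARALLEL_SIZE,
        strategy="STRICT_PACK",
    )
    return dict(
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            pg, placement_group_capture_child_tasks=True
        )
    )


class ToyPredictor:
    """Stand-in for DistributedPredictor: constructed once per actor, called once per batch."""

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        prompts = batch["item"].tolist()
        # A real predictor would call model.inference(...) here; we fake a per-prompt score.
        return {"input": prompts, "output": [float(len(p)) for p in prompts]}


if __name__ == "__main__":
    resources_kwarg: Dict[str, Any] = {}
    if TENSOR_PARALLEL_SIZE == 1:
        resources_kwarg["num_gpus"] = 1            # one whole GPU per actor
    else:
        resources_kwarg["num_gpus"] = 0            # GPUs come from the placement group instead
        resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn

    ds = ray.data.from_items(["prompt one", "prompt two", "prompt three"])
    mapped = ds.map_batches(
        ToyPredictor,
        concurrency=NUM_INSTANCES,   # number of actor instances
        batch_size=2,
        **resources_kwarg,
    )
    print(mapped.to_pandas())        # pipeline is lazy; the actual forwards run here

As in the patch, the dataset is only materialized at `to_pandas()`, so actor startup cost is paid once per replica rather than once per batch.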