[feature] multi instance vllm inference and rm inference
wheresmyhair committed Jul 19, 2024
1 parent 3d99d8a commit c677815
Showing 5 changed files with 339 additions and 67 deletions.
1 change: 1 addition & 0 deletions examples/vllm_inference.py
@@ -53,6 +53,7 @@ def main():
dataset,
release_gpu=False,
enable_decode_inference_result=pipeline_args.enable_decode_inference_result,
enable_distributed_vllm_inference=pipeline_args.enable_distributed_vllm_inference,
)


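The example script only needs to forward the new switch: it reads enable_distributed_vllm_inference from the parsed pipeline arguments and passes it straight into the inference call. As a rough illustration (the class name and help text below are assumptions, not taken from this commit), such a pipeline argument is typically declared as a boolean dataclass field so that HfArgumentParser can expose it as the --enable_distributed_vllm_inference flag used by the shell script in the next file:

# Hypothetical sketch of declaring the boolean pipeline argument; not the actual
# LMFlow argument class.
from dataclasses import dataclass, field

@dataclass
class VLLMInferencerArgumentsSketch:
    enable_distributed_vllm_inference: bool = field(
        default=False,
        metadata={"help": "Shard vLLM inference across multiple Ray-managed instances."},
    )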
1 change: 1 addition & 0 deletions scripts/run_vllm_inference.sh
@@ -74,4 +74,5 @@ python examples/vllm_inference.py \
--enable_decode_inference_result False \
--vllm_gpu_memory_utilization 0.95 \
--vllm_tensor_parallel_size 2 \
--enable_distributed_vllm_inference False \
2>&1 | tee ${log_dir}/vllm_inference.log
210 changes: 170 additions & 40 deletions src/lmflow/pipeline/rm_inferencer.py
@@ -12,9 +12,12 @@
import json
import time
import logging
from typing import Dict, List, Union, Tuple
from typing import Dict, List, Union, Tuple, Any

from accelerate import Accelerator
import ray
import ray.data
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
import torch
from tqdm import tqdm
from transformers import AutoConfig
@@ -32,7 +35,8 @@
from lmflow.pipeline.base_pipeline import BasePipeline
from lmflow.utils.data_utils import (
set_random_seed,
batchlize
batchlize,
RewardModelInferenceResultWithInput,
)
from lmflow.datasets.dataset import KEY_SCORE

@@ -61,6 +65,7 @@ def __init__(
model_args: ModelArguments,
data_args: DatasetArguments,
inferencer_args: InferencerArguments,
**kwargs,
):
self.data_args = data_args
self.inferencer_args = inferencer_args
@@ -79,8 +84,7 @@ def __init__(
)

if inferencer_args.use_accelerator:
self.accelerator = Accelerator()
self.accelerator.wait_for_everyone()
self.accelerator: Accelerator = kwargs.get('accelerator', Accelerator())


def inference(
@@ -89,6 +93,8 @@
dataset: Dataset,
transform_dataset_in_place: bool=True,
use_vllm: bool = False,
enable_distributed_inference: bool = False,
**kwargs,
) -> Dataset:
if use_vllm:
logger.warning("VLLM doesn't support reward model inference, using normal inference instead.")
@@ -98,49 +104,73 @@
if not transform_dataset_in_place:
dataset = copy.deepcopy(dataset)

output_dict = {"type": "", "instances": []}
if dataset.get_type() == "text_to_textlist":
output_dict["type"] = "text_to_scored_textlist"
for idx, instance in enumerate(dataset.get_backend_dataset()):
if len(instance["output"]) < 2:
logger.warning(f"Instance {idx} has less than 2 outputs, skipping.")
output_dict["instances"].append(
{
"input": instance["input"],
"output": [{"text": text} for text in instance["output"]],
}
)
else:
raise NotImplementedError(f"Dataset type {dataset.get_type()} is not supported for reward model inference.")

model_input = model.prepare_inputs_for_inference(
dataset=dataset,
apply_chat_template=True,
enable_distributed_inference=enable_distributed_inference,
use_vllm=use_vllm
)

if use_vllm:
scores = self.__vllm_inference(model, dataset)
inference_result = self.__vllm_inference(
model=model,
model_input=model_input,
enable_distributed_inference=enable_distributed_inference,
)
else:
scores = self.__inference(model, dataset)

for i, instance_scores in enumerate(scores):
for j, score in enumerate(instance_scores):
output_dict["instances"][i]["output"][j][KEY_SCORE] = score
inference_result = self._inference(
model=model,
model_input=model_input,
enable_distributed_inference=enable_distributed_inference,
**kwargs,
)

output_dataset_args = copy.deepcopy(self.data_args)
output_dataset_args.dataset_path = None
output_dataset_args.dataset_name = f"{output_dataset_args.dataset_name}_scored"
output_dataset = Dataset(output_dataset_args)
output_dataset = output_dataset.from_dict(output_dict)
if enable_distributed_inference:
output_dataset = model.postprocess_distributed_inference_outputs(
dataset=dataset,
inference_result=inference_result,
)
else:
output_dataset = model.postprocess_inference_outputs(
dataset=dataset,
scores=inference_result
)

return output_dataset
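# --- Illustrative sketch (not part of this commit) --------------------------------
# The inline postprocessing removed above built a "text_to_scored_textlist" dict and
# wrote KEY_SCORE for every candidate; that logic now lives in the model's
# postprocess_inference_outputs / postprocess_distributed_inference_outputs methods.
# Reconstructed roughly from the removed code, the non-distributed path amounts to the
# following; attach_scores_sketch is a hypothetical name used only for illustration.
def attach_scores_sketch(instances, scores):
    # instances: [{"input": str, "output": [str, ...]}, ...]
    # scores: one list of floats per instance, aligned with its candidate outputs
    return {
        "type": "text_to_scored_textlist",
        "instances": [
            {
                "input": instance["input"],
                "output": [
                    {"text": text, "score": score}
                    for text, score in zip(instance["output"], instance_scores)
                ],
            }
            for instance, instance_scores in zip(instances, scores)
        ],
    }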


def _inference(
self,
model: HFTextRegressionModel,
model_input: Union[Dataset, ray.data.Dataset],
enable_distributed_inference: bool = False,
**kwargs,
):
if enable_distributed_inference:
inference_res = self.__distributed_inference(
model=model,
model_input=model_input,
num_instances=kwargs.get("distributed_inference_num_instances", 1),
batch_size=kwargs.get("inference_batch_size", 1),
)
else:
inference_res = self.__inference(
model=model,
model_input=model_input,
)

return inference_res


def __inference(
self,
model: HFTextRegressionModel,
dataset: Dataset,
model_input: Dataset,
) -> Union[List[float], List[List[float]]]:
tokenized_dataset = model.tokenize(dataset)
if dataset.get_type() in ["text_to_textlist"]:
model_input_ids, num_outputs = self.flatten_list(tokenized_dataset.get_backend_dataset()["input_ids"])
if model_input.get_type() in ["text_to_textlist"]:
model_input_ids, num_outputs = self.flatten_list(model_input.get_backend_dataset()["input_ids"])
else:
model_input_ids = tokenized_dataset.get_backend_dataset()["input_ids"]
model_input_ids = model_input.get_backend_dataset()["input_ids"]

dataloader = batchlize(
examples=model_input_ids,
@@ -157,32 +187,132 @@ def __inference(
unit="batch"
):
# len(batch) = batch_size, and batch element is dataset sample
model_input = torch.LongTensor(batched_input_ids).to("cpu" if model.device == "cpu" else "cuda")
model_input_tensor = torch.LongTensor(batched_input_ids).to("cpu" if model.device == "cpu" else "cuda")
if self.inferencer_args.use_accelerator:
with self.accelerator.autocast():
batch_output = model.inference(
inputs=model_input,
inputs=model_input_tensor,
use_vllm=False,
)
else:
batch_output = model.inference(
inputs=model_input,
inputs=model_input_tensor,
use_vllm=False,
)

batch_output = self.__post_process_model_output(batch_output)
final_output.extend(batch_output)

if dataset.get_type() in ["text_to_textlist"]:
if model_input.get_type() in ["text_to_textlist"]:
final_output = self.compress_list(final_output, num_outputs)

return final_output
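# --- Illustrative sketch (not part of this commit) --------------------------------
# For "text_to_textlist" inputs each prompt carries several candidate outputs, so the
# nested input_ids are flattened into one batch before scoring and the scores are
# regrouped per prompt afterwards. The real flatten_list / compress_list helpers are
# defined elsewhere in this class; the bodies below are assumptions that only mirror
# how they are called above.
def flatten_list_sketch(nested):
    flat = [item for sublist in nested for item in sublist]
    num_outputs = [len(sublist) for sublist in nested]
    return flat, num_outputs

def compress_list_sketch(flat, num_outputs):
    grouped, start = [], 0
    for n in num_outputs:
        grouped.append(flat[start:start + n])
        start += n
    return grouped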


def __distributed_inference(
self,
model: HFTextRegressionModel,
model_input: ray.data.Dataset,
num_instances: int,
batch_size: int,
) -> List[RewardModelInferenceResultWithInput]:
def scheduling_strategy_fn():
# One bundle per tensor parallel worker
pg = ray.util.placement_group(
[{
"GPU": 1,
"CPU": 1
}] * self.inferencer_args.tensor_parallel_size,
strategy="STRICT_PACK",
)
return dict(
scheduling_strategy=PlacementGroupSchedulingStrategy(
pg, placement_group_capture_child_tasks=True
)
)
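# Note: STRICT_PACK requires all bundles of the placement group to land on a single
# node, so the tensor-parallel workers of one replica stay co-located, and
# placement_group_capture_child_tasks=True schedules any child tasks or actors spawned
# by the replica inside the same group.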

resources_kwarg: Dict[str, Any] = {}
if self.inferencer_args.tensor_parallel_size == 1:
# For tensor_parallel_size == 1, we simply set num_gpus=1.
resources_kwarg["num_gpus"] = 1
else:
# Otherwise, we have to set num_gpus=0 and provide
# a function that will create a placement group for
# each instance.
resources_kwarg["num_gpus"] = 0
resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn

## predictor
class DistributedPredictor:
def __init__(
self,
model_args: ModelArguments,
):
self.model = HFTextRegressionModel(
model_args=model_args,
tune_strategy='none',
use_accelerator=True
)
self.model.activate_model_for_inference(use_vllm=False)

def __call__(self, batch: Dict[str, np.ndarray]):
"""batch: Dict[str, np.ndarray]
Example (batch size=2):
{'input': array(['...','...'], dtype=object),
'output': array([array(["...", "..."], dtype=object), array(['...','...'], dtype=object)], dtype=object),
'input_ids': array([[[128000, 128006, 882, ..., 128256, 128256, 128256],
[128000, 128006, 882, ..., 128256, 128256, 128256]],
[[128000, 128006, 882, ..., 128256, 128256, 128256],
[128000, 128006, 882, ..., 128256, 128256, 128256]]])}
"""
# The batch is managed by Ray, and the actual batch size may be smaller than the
# configured inference_batch_size because of remainders. For example, with 10
# examples, 2 inference instances, and inference_batch_size=4, one instance ends
# up with a batch of only 2 leftover examples, so the actual batch size changes.
actual_batch_size = len(batch['input'])
batched_inference_res = self.model.inference(
inputs=torch.LongTensor(batch['input_ids']).flatten(start_dim=0, end_dim=1).to("cuda"),
).logits
batched_inference_res = batched_inference_res.to("cpu").reshape(actual_batch_size, -1, 1).squeeze(dim=-1).tolist() # [bs, num_output_sequences]
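# Shape walk-through: input_ids arrive as [actual_batch_size, num_candidates, seq_len];
# flatten(start_dim=0, end_dim=1) yields [actual_batch_size * num_candidates, seq_len]
# for a single forward pass, the regression head returns one logit per sequence, and
# reshape(actual_batch_size, -1, 1).squeeze(dim=-1) restores [actual_batch_size,
# num_candidates] so each prompt keeps one score per candidate output.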
batched_final_res = {
"input": batch['input'].tolist(),
"output": [
[
{"score": batched_inference_res[j][i], "text": batch["output"][j][i]}
for i in range(len(batch['output'][j]))
]
for j in range(actual_batch_size)
],
} # structure the result this way since it is written into a pandas DataFrame
return batched_final_res

# inference
model_input_mapping = model_input.map_batches(
DistributedPredictor,
concurrency=num_instances, # Set the concurrency to the number of LLM instances.
batch_size=batch_size,
fn_constructor_kwargs={
"model_args": model.model_args,
},
**resources_kwarg,
)

df_model_output = model_input_mapping.to_pandas() # the actual forward passes run here (map_batches is lazy until this point)
logger.info(f"Distributed reward model inference result preview:\n{df_model_output.head(10)}")

model_output = [
{"input": row["input"], "output": row["output"]} for _, row in df_model_output.iterrows()
]

return model_output
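# --- Illustrative sketch (not part of this commit) --------------------------------
# The Ray Data pattern above, reduced to a self-contained toy: a callable class is
# instantiated once per replica (concurrency sets the replica count), map_batches
# streams batches through it, and to_pandas() triggers the lazy execution. ToyScorer
# is an assumption standing in for the reward model.
import numpy as np
import ray.data

class ToyScorer:
    def __init__(self, bias: float):
        self.bias = bias  # stands in for loading the model once per replica

    def __call__(self, batch):
        # batch is a dict of numpy arrays keyed by column name
        batch["score"] = np.array([len(text) + self.bias for text in batch["input"]])
        return batch

toy_ds = ray.data.from_items([{"input": "hello"}, {"input": "a longer prompt"}])
toy_scores = toy_ds.map_batches(
    ToyScorer,
    concurrency=2,                        # number of ToyScorer replicas
    batch_size=2,
    fn_constructor_kwargs={"bias": 0.0},  # forwarded to ToyScorer.__init__
)
print(toy_scores.to_pandas())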


def __vllm_inference(
self,
model: HFTextRegressionModel,
dataset: Dataset,
model_input: List[str],
enable_distributed_inference: bool = False,
) -> List[float]:
raise NotImplementedError("VLLM inference for reward model is not implemented yet.")

3 changes: 3 additions & 0 deletions src/lmflow/pipeline/utils/memory_safe_vllm_inference.py
@@ -55,6 +55,9 @@ def main():
dataset,
release_gpu=False,
enable_decode_inference_result=pipeline_args.enable_decode_inference_result,
enable_distributed_inference=pipeline_args.enable_distributed_inference,
distributed_inference_num_instances=pipeline_args.distributed_inference_num_instances,
inference_batch_size=pipeline_args.vllm_inference_batch_size,
)

# use this as a flag, stdout will be captured by the pipeline
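The memory-safe wrapper forwards the same distributed-inference options as the example script. Judging from the comment above ("use this as a flag, stdout will be captured by the pipeline"), the script appears to report completion to a parent process through its standard output. A rough parent-side sketch of that pattern, in which the sentinel string and the bare command line are assumptions rather than values taken from this commit:

# Hypothetical parent-side sketch: run the inference script in a separate process so
# that its GPU memory is fully released when it exits, and scan stdout for a flag.
import subprocess
import sys

FINISH_FLAG = "MEMORY_SAFE_VLLM_INFERENCE_DONE"  # placeholder sentinel, assumption

cmd = [sys.executable, "src/lmflow/pipeline/utils/memory_safe_vllm_inference.py"]  # plus the usual CLI arguments
proc = subprocess.run(cmd, capture_output=True, text=True)
succeeded = FINISH_FLAG in proc.stdout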
