Add CPU support for chatbot
research4pan committed Apr 11, 2023
1 parent 9780c2f commit 3034460
Showing 5 changed files with 69 additions and 14 deletions.
1 change: 1 addition & 0 deletions examples/chatbot.py
@@ -69,6 +69,7 @@ def main():
         model_args,
         tune_strategy='none',
         ds_config=ds_config,
+        device=pipeline_args.device,
     )
 
     # We don't need input data, we will read interactively from stdin
17 changes: 17 additions & 0 deletions scripts/run_chatbot_cpu.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+model=gpt2
+lora_args=""
+if [ $# -ge 1 ]; then
+  model=$1
+fi
+if [ $# -ge 2 ]; then
+  lora_args="--lora_model_path $2"
+fi
+
+CUDA_VISIBLE_DEVICES="" \
+  python examples/chatbot.py \
+      --deepspeed configs/ds_config_chatbot.json \
+      --model_name_or_path ${model} \
+      --device "cpu" \
+      ${lora_args}
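Usage sketch (not part of the commit; assumes the script is run from the repository root): bash scripts/run_chatbot_cpu.sh starts the chatbot on CPU with the default gpt2 model, while bash scripts/run_chatbot_cpu.sh <model_name_or_path> <lora_model_path> overrides the base model via the first positional argument and optionally supplies a LoRA adapter path via the second.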
7 changes: 7 additions & 0 deletions src/lmflow/args.py
@@ -493,6 +493,13 @@ class InferencerArguments:
         mixed precision mode, whether to use bf16 or fp16
     """
+    device: str = field(
+        default="gpu",
+        metadata={
+            "help": "device of chatbot",
+            "choices": ["gpu", "cpu"],
+        },
+    )
     local_rank: int = field(
         default=-1,
         metadata={"help": "For distributed training: local_rank"
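A minimal, self-contained sketch (not part of the commit) of how the new device choice surfaces on the command line, assuming InferencerArguments is parsed with transformers.HfArgumentParser as in examples/chatbot.py; _DeviceArgs below is a hypothetical stand-in that mirrors only the added field:

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    @dataclass
    class _DeviceArgs:
        # Same declaration as the device field added to InferencerArguments above.
        device: str = field(
            default="gpu",
            metadata={"help": "device of chatbot", "choices": ["gpu", "cpu"]},
        )

    parser = HfArgumentParser(_DeviceArgs)
    (args,) = parser.parse_args_into_dataclasses(args=["--device", "cpu"])
    print(args.device)  # "cpu"; any value outside ["gpu", "cpu"] is rejected by argparse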
37 changes: 26 additions & 11 deletions src/lmflow/models/hf_decoder_model.py
@@ -81,6 +81,7 @@ def __init__(
         model_args,
         tune_strategy='normal',
         ds_config=None,
+        device="gpu",
         *args,
         **kwargs
     ):
@@ -100,6 +101,7 @@ def __init__(
         # Distributed training: The .from_pretrained methods guarantee that
         # only one local process can concurrently download model & vocab.
 
+        self.device = device
 
         if tune_strategy == 'normal':
             config_kwargs = {
@@ -228,10 +230,10 @@ def __init__(
                     self.backend_model, peft_model_id
                 )
 
-
-            deepspeed.init_distributed()
-            self.ds_engine = deepspeed.initialize(model=self.backend_model, config_params=ds_config)[0]
-            self.ds_engine.module.eval()
+            if device == "gpu":
+                deepspeed.init_distributed()
+                self.ds_engine = deepspeed.initialize(model=self.backend_model, config_params=ds_config)[0]
+                self.ds_engine.module.eval()
 
         elif tune_strategy == 'adapter':
             raise NotImplementedError('adapter tune strategy not implemented')
@@ -400,13 +402,26 @@ def inference(self, inputs, *args, **kwargs):
 
 
         with torch.no_grad():
-            outputs = self.ds_engine.module.generate(
-                input_ids=inputs,
-                synced_gpus=True,
-                pad_token_id=self.tokenizer.eos_token_id,
-                *args,
-                **kwargs
-            )
+            if self.device == "gpu":
+                outputs = self.ds_engine.module.generate(
+                    input_ids=inputs,
+                    synced_gpus=True,
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    *args,
+                    **kwargs
+                )
+            elif self.device == "cpu":
+                outputs = self.backend_model.generate(
+                    input_ids=inputs,
+                    synced_gpus=True,
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    *args,
+                    **kwargs
+                )
+            else:
+                raise NotImplementedError(
+                    f"device \"{self.device}\" is not supported"
+                )
         return outputs
 
 
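A short construction sketch (not part of the commit) of the CPU path this change enables, mirroring how examples/chatbot.py builds the model; constructing ModelArguments directly with model_name_or_path and leaving ds_config as None on CPU are my assumptions:

    from lmflow.args import ModelArguments
    from lmflow.models.hf_decoder_model import HFDecoderModel

    model_args = ModelArguments(model_name_or_path="gpt2")
    model = HFDecoderModel(
        model_args,
        tune_strategy='none',   # inference-only path, as in examples/chatbot.py
        ds_config=None,         # assumed unused here: deepspeed.initialize is skipped on CPU
        device="cpu",           # new keyword introduced by this commit
    )
    # With device="cpu", inference() later dispatches to self.backend_model.generate()
    # instead of self.ds_engine.module.generate().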
21 changes: 18 additions & 3 deletions src/lmflow/pipeline/inferencer.py
@@ -47,8 +47,15 @@ def __init__(self, model_args, data_args, inferencer_args):
 
         self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
         self.world_size = int(os.getenv("WORLD_SIZE", "1"))
-        torch.cuda.set_device(self.local_rank) # NOTE: cpu-only machine will have error
-        deepspeed.init_distributed()
+        if inferencer_args.device == "gpu":
+            torch.cuda.set_device(self.local_rank) # NOTE: cpu-only machine will have error
+            deepspeed.init_distributed()
+        else:
+            os.environ["MASTER_ADDR"] = "localhost"
+            os.environ["MASTER_PORT"] = "15000"
+            dist.init_process_group(
+                "gloo", rank=self.local_rank, world_size=self.world_size
+            )
 
         self.config = AutoConfig.from_pretrained(model_args.model_name_or_path)
         try:
@@ -119,7 +126,15 @@ def inference(
 
             input = prompt_structure.format(input=current_batch['input'])
 
-            inputs = model.encode(input, return_tensors="pt").to(device=self.local_rank)
+            if self.inferencer_args.device == "gpu":
+                inputs = model.encode(input, return_tensors="pt").to(device=self.local_rank)
+            elif self.inferencer_args.device == "cpu":
+                inputs = model.encode(input, return_tensors="pt").to(device='cpu')
+            else:
+                raise NotImplementedError(
+                    f"device \"{device}\" is not supported"
+                )
+
             outputs = model.inference(
                 inputs,
                 max_new_tokens=max_new_tokens,
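A stand-alone sketch (not part of the commit; the rationale is my assumption) of the CPU process-group setup added above: a single-process gloo group gives torch.distributed collectives a backend to use during generation, since inference() keeps synced_gpus=True, using rank 0, world size 1, and the same hard-coded port 15000:

    import os
    import torch.distributed as dist

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "15000"
    dist.init_process_group("gloo", rank=0, world_size=1)

    print(dist.is_initialized(), dist.get_backend())  # True gloo
    dist.destroy_process_group()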
