support image encoder with image caption as example
lianqing11 committed Jun 9, 2023
1 parent 5dd0d2c commit 4542349
Showing 9 changed files with 198 additions and 30 deletions.
17 changes: 17 additions & 0 deletions configs/ds_config_multimodal.json
@@ -0,0 +1,17 @@
{
"fp16": {
"enabled": false
},
"bf16": {
"enabled": false
},
"comms_logger": {
"enabled": false,
"verbose": false,
"prof_all": false,
"debug": false
},
"steps_per_print": 20000000000000000,
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": false
}
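In short, this DeepSpeed config keeps multimodal inference minimal: fp16/bf16 and the communication logger are disabled, the per-GPU micro batch size is 1, and steps_per_print is set to an effectively unreachable value so DeepSpeed emits no per-step logs during captioning.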
80 changes: 80 additions & 0 deletions examples/inference.py
@@ -0,0 +1,80 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
"""A simple shell to inference the input data.
"""
import logging
import json
import os
import sys
sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
import warnings

from dataclasses import dataclass, field
from transformers import HfArgumentParser
from typing import Optional

from lmflow.datasets.dataset import Dataset
from lmflow.pipeline.auto_pipeline import AutoPipeline
from lmflow.models.auto_model import AutoModel
from lmflow.args import ModelArguments, DatasetArguments, AutoArguments
from PIL import Image
import torch
import requests
from transformers import BlipProcessor, BlipForConditionalGeneration


logging.disable(logging.ERROR)
warnings.filterwarnings("ignore")


def main():
    pipeline_name = "inferencer"
    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

    parser = HfArgumentParser((
        ModelArguments,
        PipelineArguments,
    ))

    model_args, pipeline_args = (
        parser.parse_args_into_dataclasses()
    )
    inferencer_args = pipeline_args

    with open(pipeline_args.deepspeed, "r") as f:
        ds_config = json.load(f)

    model = AutoModel.get_model(
        model_args,
        tune_strategy='none',
        ds_config=ds_config,
        device=pipeline_args.device,
    )

    data_args = DatasetArguments(dataset_path=None)
    dataset = Dataset(data_args)

    inferencer = AutoPipeline.get_pipeline(
        pipeline_name=pipeline_name,
        model_args=model_args,
        data_args=data_args,
        pipeline_args=pipeline_args,
    )

    img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
    input_dataset = dataset.from_dict({
        "type": "image_text",
        "instances": [{"images": raw_image,
                       "text": ""}],
    })

    prompt_text = "a photography of"
    output = inferencer.inference(model, input_dataset,
                                  prompt_structure=prompt_text + "{input}")
    print(output.backend_dataset['text'])


if __name__ == "__main__":
    main()
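For a quick sanity check, the BLIP classes that this example already imports can reproduce the same captioning flow directly with transformers. The following is a minimal sketch, not part of the commit, assuming the Salesforce/blip-image-captioning-base checkpoint used by the shell script below:

from PIL import Image
import requests
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

# conditional captioning: the text prompt plays the same role as prompt_structure above
inputs = processor(raw_image, "a photography of", return_tensors="pt")
out = blip_model.generate(**inputs, max_new_tokens=30)
print(processor.decode(out[0], skip_special_tokens=True))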
17 changes: 17 additions & 0 deletions scripts/inference_multimodal_model.sh
@@ -0,0 +1,17 @@
#!/bin/bash

model="Salesforce/blip-image-captioning-base"
lora_args=""
if [ $# -ge 1 ]; then
  model=$1
fi
if [ $# -ge 2 ]; then
  lora_args="--lora_model_path $2"
fi

CUDA_VISIBLE_DEVICES=7 \
deepspeed examples/inference.py \
--deepspeed configs/ds_config_multimodal.json \
--model_name_or_path ${model} \
--arch_type visionEncoder_decoder \
${lora_args}
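The script defaults to Salesforce/blip-image-captioning-base; a different model name can be passed as the first argument and an optional LoRA adapter path as the second, e.g. bash scripts/inference_multimodal_model.sh Salesforce/blip-image-captioning-base. Note that CUDA_VISIBLE_DEVICES is hard-coded to GPU 7 and may need adjusting.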
2 changes: 1 addition & 1 deletion src/lmflow/args.py
@@ -121,7 +121,7 @@ class ModelArguments:
"Model architecture type, e.g. \"decoder_only\","
" \"encoder_decoder\""
),
"choices": ["decoder_only", "encoder_decoder", "text_regression"],
"choices": ["decoder_only", "encoder_decoder", "text_regression", "visionEncoder_decoder"],
},
)
config_name: Optional[str] = field(
2 changes: 1 addition & 1 deletion src/lmflow/datasets/dataset.py
@@ -30,6 +30,7 @@
"text_only",
"text2text",
"float_only",
"image_text",
]

KEY_TYPE = "type"
@@ -84,7 +85,6 @@ def __init__(self, data_args=None, backend: str="huggingface", *args, **kwargs):
f' ]\n'
'}'
)

if self.type is None:
self.type = json_data[KEY_TYPE]
elif self.type != json_data[KEY_TYPE]:
3 changes: 2 additions & 1 deletion src/lmflow/models/auto_model.py
@@ -16,7 +16,8 @@ def get_model(self, model_args, *args, **kwargs):
return HFDecoderModel(model_args, *args, **kwargs)
elif arch_type == "text_regression":
return TextRegressionModel(model_args, *args, **kwargs)
elif arch_type == "encoder_decoder":
elif arch_type == "encoder_decoder" or \
arch_type == "visionEncoder_decoder":
return HFEncoderDecoderModel(model_args, *args, **kwargs)
else:
raise NotImplementedError(
61 changes: 45 additions & 16 deletions src/lmflow/models/hf_encoder_decoder_model.py
@@ -42,13 +42,15 @@
AutoConfig,
AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoModelForVision2Seq,
AutoModel,
AutoProcessor,
)

from lmflow.datasets.dataset import Dataset
from lmflow.models.encoder_decoder_model import EncoderDecoderModel
from lmflow.models.interfaces.tunable import Tunable

import copy

logger = logging.getLogger(__name__)

@@ -123,14 +125,24 @@ def __init__(
)
model_args.use_ram_optimized_load = False


# get model register
self.arch_type = model_args.arch_type
if self.arch_type == "encoder_decoder":
if model_args.model_name_or_path == 'THUDM/chatglm-6b':
model_register = AutoModel
else:
model_register = AutoModelForSeq2SeqLM
elif self.arch_type == "visionEncoder_decoder":
model_register = AutoModelForVision2Seq
else:
raise NotImplementedError
if model_args.model_name_or_path == 'THUDM/chatglm-6b':
self.backend_model = AutoModel.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

self.backend_model = model_register.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
elif model_args.use_ram_optimized_load and peft_model_id is None:
try:
# RAM-optimized load
self.backend_model = AutoModelForSeq2SeqLM.from_pretrained(
self.backend_model = model_register.from_pretrained(
model_args.model_name_or_path,
device_map="auto",
offload_folder="offload",
@@ -142,7 +154,7 @@ def __init__(
" use original load instead."
)
# Normal load
self.backend_model = AutoModelForSeq2SeqLM.from_pretrained(
self.backend_model = model_register.from_pretrained(
model_args.model_name_or_path,
)
else:
@@ -151,11 +163,17 @@ def __init__(
"LoRA does not support RAM optimized load currently."
" Automatically use original load instead."
)
self.backend_model = AutoModelForSeq2SeqLM.from_pretrained(
self.backend_model = model_register.from_pretrained(
model_args.model_name_or_path,
)
if self.arch_type == "encoder_decoder":
tokenizer_register = AutoTokenizer
elif self.arch_type == "visionEncoder_decoder":
tokenizer_register = AutoProcessor
else:
raise NotImplementedError

self.tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
self.tokenizer = tokenizer_register.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
self.backend_model_full = self.backend_model
if peft_model_id is not None:
self.backend_model = PeftModel.from_pretrained(
@@ -172,10 +190,11 @@ def __init__(
elif tune_strategy == 'adapter':
raise NotImplementedError('adapter tune strategy not implemented')

if self.tokenizer.eos_token_id is None:
self.tokenizer.eos_token_id = self.backend_model.config.eos_token_id
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
if self.arch_type == "encoder_decoder":
if self.tokenizer.eos_token_id is None:
self.tokenizer.eos_token_id = self.backend_model.config.eos_token_id
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

def tokenize(self, dataset, *args, **kwargs):
"""
@@ -219,7 +238,11 @@ def encode(self, input: Union[str, List[str]], *args, **kwargs ) -> Union[List[i
outputs :
The tokenized inputs.
"""
if isinstance(input, list):
if isinstance(input, dict):
# TODO refactor the input type to make it more elegant.
kwargs.update(input)
return self.tokenizer(*args, **kwargs)
elif isinstance(input, list):
return self.tokenizer(text=input, *args, **kwargs)  # batch encode; will automatically do left padding
elif isinstance(input, str):
return self.tokenizer.encode(text=input, *args, **kwargs)
@@ -276,22 +299,28 @@ def inference(self, inputs, *args, **kwargs):
outputs :
The generated sequence output
"""

# TODO need to discuss how to handle pad_token_id
if self.arch_type == "encoder_decoder":
kwargs.update(pad_token_id=self.tokenizer.pad_token_id)
elif self.arch_type == "visionEncoder_decoder":
# TODO discuss how to modify the interface to remove this part.
inputs = copy.deepcopy(inputs)
input_ids = inputs.pop('input_ids')
kwargs.update(**inputs)
inputs = input_ids

with torch.no_grad():
if self.device == "gpu":
outputs = self.ds_engine.module.generate(
input_ids=inputs,
synced_gpus=True,
pad_token_id=self.tokenizer.pad_token_id,
*args,
**kwargs
)
elif self.device == "cpu":
outputs = self.backend_model.generate(
input_ids=inputs,
synced_gpus=True,
pad_token_id=self.tokenizer.pad_token_id,
*args,
**kwargs
)
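In summary, for the new visionEncoder_decoder arch type the model is loaded via AutoModelForVision2Seq, the tokenizer is replaced by an AutoProcessor, encode() accepts a dict and forwards it to that processor, and inference() pops input_ids from the processed inputs and passes the remaining tensors (such as the processor's pixel_values) to generate() as keyword arguments.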
45 changes: 34 additions & 11 deletions src/lmflow/pipeline/inferencer.py
@@ -24,6 +24,11 @@
def rstrip_partial_utf8(string):
return string.replace("\ufffd", "")

supported_dataset_type = [
"text_only",
"image_text",
]

class Inferencer(BasePipeline):
"""
Initializes the `Inferencer` class with given arguments.
@@ -69,8 +74,24 @@ def __init__(self, model_args, data_args, inferencer_args):


def create_dataloader(self, dataset: Dataset):
data_dict = dataset.to_dict()
inputs = [ instance["text"] for instance in data_dict["instances"] ]
r"""Batchlize dataset and format it to dataloader.
Args:
dataset (Dataset): the dataset object
Output:
dataloader (batchlize): the dataloader object
dataset_size (int): the length of the dataset
"""
if dataset.get_type() == "text_only":
data_dict = dataset.to_dict()
inputs = [instance["text"] for instance in data_dict["instances"] ]
elif dataset.get_type() == "image_text":
backend_dataset = dataset.get_backend_dataset()
# cannot use to_dict() here because the instances contain images.
inputs = [backend_dataset.__getitem__(idx) \
for idx in range(len(backend_dataset))]
dataset_size = len(inputs)
dataset_buf = []
for idx in range(dataset_size):
@@ -110,10 +131,10 @@ def inference(
output_dataset: Dataset object.
"""
if dataset.get_type() != "text_only":
if dataset.get_type() not in supported_dataset_type:
raise NotImplementedError(
'input dataset should have type "text_only"'
)
'input dataset should have type {}'.format(
supported_dataset_type))

dataloader, data_size = self.create_dataloader(dataset)

@@ -126,8 +147,11 @@ def inference(

for batch_index, batch in enumerate(dataloader):
current_batch = batch[0] # batch size is 1

input = prompt_structure.format(input=current_batch['input'])
if isinstance(current_batch['input'], str):
input = prompt_structure.format(input=current_batch['input'])
else:
input = current_batch['input']
input['text'] = prompt_structure.format(input=input['text'])

if self.inferencer_args.device == "gpu":
inputs = model.encode(input, return_tensors="pt").to(device=self.local_rank)
@@ -137,7 +161,6 @@ def inference(
raise NotImplementedError(
f"device \"{self.inferencer_args.device}\" is not supported"
)

outputs = model.inference(
inputs,
max_new_tokens=self.inferencer_args.max_new_tokens,
@@ -146,10 +169,10 @@ def inference(
do_sample=self.inferencer_args.do_sample,
)
text_out = model.decode(outputs[0], skip_special_tokens=True)

# only return the generation, truncating the input
prompt_length = len(model.decode(inputs[0], skip_special_tokens=True,))
text_out = text_out[prompt_length:]
if self.model_args.arch_type != "visionEncoder_decoder":
prompt_length = len(model.decode(inputs[0], skip_special_tokens=True,))
text_out = text_out[prompt_length:]
output_dict["instances"].append({ "text": text_out })

output_dataset = Dataset(DatasetArguments(dataset_path = None))
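At the pipeline level, image_text instances are pulled straight from the backend dataset, prompt_structure is applied to each instance's 'text' field rather than to a plain string, and the prompt is only stripped from the decoded output for non-visionEncoder_decoder models, since the caption is generated from the image rather than appended to the prompt.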
1 change: 1 addition & 0 deletions src/lmflow/utils/constants.py
@@ -167,4 +167,5 @@
"text_only": ["text"],
"text2text": ["input", "output"],
"float_only": ["value"],
"image_text": ["images", "text"],
}
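For reference, the "images"/"text" keys registered here match what examples/inference.py builds; a minimal sketch of constructing such a dataset (the local image path is hypothetical) looks like:

from PIL import Image

from lmflow.args import DatasetArguments
from lmflow.datasets.dataset import Dataset

raw_image = Image.open("demo.jpg").convert("RGB")  # any RGB image; path is hypothetical

dataset = Dataset(DatasetArguments(dataset_path=None))
input_dataset = dataset.from_dict({
    "type": "image_text",
    "instances": [{"images": raw_image, "text": ""}],
})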
