
fix the code style for the multi-modal module
lianqing11 committed Jun 12, 2023
1 parent 4542349 commit aaead4f
Showing 5 changed files with 10 additions and 10 deletions.
@@ -9,9 +9,9 @@ if [ $# -ge 2 ]; then
     lora_args="--lora_model_path $2"
 fi
 
-CUDA_VISIBLE_DEVICES=7 \
+CUDA_VISIBLE_DEVICES=0 \
 deepspeed examples/inference.py \
     --deepspeed configs/ds_config_multimodal.json \
     --model_name_or_path ${model} \
-    --arch_type visionEncoder_decoder \
+    --arch_type vision_encoder_decoder \
     ${lora_args}
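
The hunk above suggests the launcher reads the model path into ${model} from its first argument and an optional LoRA adapter path from its second. A minimal invocation sketch follows; the script's own filename is not shown in this diff, so the name used here is hypothetical:

    # Hypothetical script name -- substitute the launcher's real path.
    # $1 -> model checkpoint for --model_name_or_path, $2 -> optional LoRA adapter.
    bash scripts/run_multimodal_inference.sh path/to/model path/to/lora_adapter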
2 changes: 1 addition & 1 deletion src/lmflow/args.py
@@ -121,7 +121,7 @@ class ModelArguments:
                 "Model architecture type, e.g. \"decoder_only\","
                 " \"encoder_decoder\""
             ),
-            "choices": ["decoder_only", "encoder_decoder", "text_regression", "visionEncoder_decoder"],
+            "choices": ["decoder_only", "encoder_decoder", "text_regression", "vision_encoder_decoder"],
         },
     )
     config_name: Optional[str] = field(
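
Because arch_type is parsed with HfArgumentParser, the rename is visible on the command line: the "choices" metadata is forwarded to argparse, so a misspelled architecture name now fails at parse time. A self-contained sketch of that flow, simplified from the field above (the default value here is an assumption):

    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser

    @dataclass
    class ModelArguments:
        # Simplified from src/lmflow/args.py; "choices" is passed through to
        # argparse, which rejects values outside the list.
        arch_type: Optional[str] = field(
            default="decoder_only",  # assumed default for this sketch
            metadata={
                "help": "Model architecture type",
                "choices": ["decoder_only", "encoder_decoder",
                            "text_regression", "vision_encoder_decoder"],
            },
        )

    parser = HfArgumentParser(ModelArguments)
    (model_args,) = parser.parse_args_into_dataclasses(
        args=["--arch_type", "vision_encoder_decoder"]
    )
    print(model_args.arch_type)  # vision_encoder_decoder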
2 changes: 1 addition & 1 deletion src/lmflow/models/auto_model.py
@@ -17,7 +17,7 @@ def get_model(self, model_args, *args, **kwargs):
         elif arch_type == "text_regression":
             return TextRegressionModel(model_args, *args, **kwargs)
         elif arch_type == "encoder_decoder" or \
-             arch_type == "visionEncoder_decoder":
+             arch_type == "vision_encoder_decoder":
             return HFEncoderDecoderModel(model_args, *args, **kwargs)
         else:
             raise NotImplementedError(
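
With this change, "encoder_decoder" and "vision_encoder_decoder" both resolve to the same HFEncoderDecoderModel wrapper. A usage sketch, assuming get_model is exposed as a class-level factory and model_args is parsed as in the previous sketch:

    # lmflow's model factory, not transformers.AutoModel.
    from lmflow.models.auto_model import AutoModel

    # With model_args.arch_type == "vision_encoder_decoder", this returns an
    # HFEncoderDecoderModel wrapping a vision-to-sequence backbone.
    model = AutoModel.get_model(model_args)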
8 changes: 4 additions & 4 deletions src/lmflow/models/hf_encoder_decoder_model.py
@@ -19,6 +19,7 @@
 """
 
 import logging
+import copy
 from typing import List, Union
 
 import deepspeed
@@ -50,7 +51,6 @@
 from lmflow.datasets.dataset import Dataset
 from lmflow.models.encoder_decoder_model import EncoderDecoderModel
 from lmflow.models.interfaces.tunable import Tunable
-import copy
 
 logger = logging.getLogger(__name__)
 
@@ -132,7 +132,7 @@ def __init__(
                 model_register = AutoModel
             else:
                 model_register = AutoModelForSeq2SeqLM
-        elif self.arch_type == "visionEncoder_decoder":
+        elif self.arch_type == "vision_encoder_decoder":
             model_register = AutoModelForVision2Seq
         else:
             raise NotImplementedError
@@ -168,7 +168,7 @@ def __init__(
         )
         if self.arch_type == "encoder_decoder":
             tokenizer_register = AutoTokenizer
-        elif self.arch_type == "visionEncoder_decoder":
+        elif self.arch_type == "vision_encoder_decoder":
             tokenizer_register = AutoProcessor
         else:
             raise NotImplementedError
@@ -302,7 +302,7 @@ def inference(self, inputs, *args, **kwargs):
         # TODO need to discuss how to handle pad_token_id
         if self.arch_type == "encoder_decoder":
             kwargs.update(pad_token_id=self.tokenizer.pad_token_id)
-        elif self.arch_type == "visionEncoder_decoder":
+        elif self.arch_type == "vision_encoder_decoder":
             # TODO discuss how to modify the interface to remove this part.
             inputs = copy.deepcopy(inputs)
             input_ids = inputs.pop('input_ids')
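
The renamed branches pair AutoModelForVision2Seq with AutoProcessor, and inference() deep-copies the batch so it can pop input_ids away from the image tensors without mutating the caller's dict. A standalone sketch of that pattern; the BLIP-2 checkpoint is an illustrative stand-in, not a model taken from this repo:

    import copy

    from transformers import AutoModelForVision2Seq, AutoProcessor

    name = "Salesforce/blip2-opt-2.7b"  # illustrative public checkpoint
    model = AutoModelForVision2Seq.from_pretrained(name)
    processor = AutoProcessor.from_pretrained(name)  # replaces AutoTokenizer here

    def generate_from_batch(inputs):
        # Mirror of the inference() branch above: copy first, then split
        # input_ids from the remaining keys (e.g. pixel_values).
        inputs = copy.deepcopy(inputs)
        input_ids = inputs.pop("input_ids")
        return model.generate(input_ids=input_ids, **inputs)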
4 changes: 2 additions & 2 deletions src/lmflow/pipeline/inferencer.py
@@ -90,7 +90,7 @@ def create_dataloader(self, dataset: Dataset):
         elif dataset.get_type() == "image_text":
             backend_dataset = dataset.get_backend_dataset()
             # can not do the to_dict conversion because the data contains images.
-            inputs = [backend_dataset.__getitem__(idx) \
+            inputs = [backend_dataset.__getitem__(idx)
                       for idx in range(len(backend_dataset))]
             dataset_size = len(inputs)
             dataset_buf = []
@@ -170,7 +170,7 @@ def inference(
             )
             text_out = model.decode(outputs[0], skip_special_tokens=True)
             # only return the generation, truncating the input
-            if self.model_args.arch_type != "visionEncoder_decoder":
+            if self.model_args.arch_type != "vision_encoder_decoder":
                 prompt_length = len(model.decode(inputs[0], skip_special_tokens=True,))
                 text_out = text_out[prompt_length:]
             output_dict["instances"].append({ "text": text_out })
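
The condition above captures why the rename matters at this last call site: for text-only architectures the generated sequence echoes the prompt, so the decoded prompt prefix is sliced off, whereas vision encoder-decoder outputs contain no textual prompt to strip. Condensed, with names taken from the hunk (model, inputs, and outputs assumed in scope):

    text_out = model.decode(outputs[0], skip_special_tokens=True)
    if model_args.arch_type != "vision_encoder_decoder":
        # The decoded output starts with the decoded prompt; slice it off by
        # character length so only newly generated text is returned.
        prompt_length = len(model.decode(inputs[0], skip_special_tokens=True))
        text_out = text_out[prompt_length:]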
