Skip to content

Commit

Permalink
Calling load_cfg() when config is required.
Browse files Browse the repository at this point in the history
Changed to type checking instead of string matching in Model.

Signed-off-by: Ahmed Umair <[email protected]>
  • Loading branch information
Umair Ahmed committed Sep 24, 2024
1 parent e885474 commit 9929be2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 18 deletions.
7 changes: 2 additions & 5 deletions crossfit/backend/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
create_nested_list_series_from_3d_ar,
)
from crossfit.utils.torch_utils import cleanup_torch_cache, concat_and_pad_tensors
from crossfit.backend.torch.loader import SortedSeqLoader


class Model:
Expand Down Expand Up @@ -61,11 +62,7 @@ def max_seq_length(self) -> int:

def get_model_output(self, all_outputs_ls, index, loader, pred_output_col) -> cudf.DataFrame:
out = cudf.DataFrame(index=index)
_index = (
loader.sort_column(index.values)
if loader.__class__.__name__ == "SortedSeqLoader"
else index
)
_index = loader.sort_column(index.values) if type(loader) == SortedSeqLoader else index

if self.model_output_type == "string":
all_outputs = [o for output in all_outputs_ls for o in output]
Expand Down
18 changes: 5 additions & 13 deletions examples/custom_ct2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def __call__(self, batch):
tokens_1d[i : i + token_ids_2d.size(1)]
for i in range(0, len(tokens_1d), token_ids_2d.size(1))
]
tokenss = self.clean_extra_tokens(tokens_2d)
tokens = self.clean_extra_tokens(tokens_2d)

tr_res = self.model.translate_batch(
tokenss,
tokens,
min_decoding_length=0,
max_decoding_length=256,
beam_size=5,
Expand All @@ -81,10 +81,7 @@ def __call__(self, batch):
class ModelForSeq2SeqModel(HFModel):
def __init__(self, config):
self.trans_config = config
self.config = AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.trans_config.pretrained_model_name_or_path,
trust_remote_code=True,
)
self.config = self.load_cfg()
super().__init__(
self.trans_config.pretrained_model_name_or_path, model_output_type="string"
)
Expand All @@ -94,10 +91,7 @@ def load_model(self, device="cuda"):
return model

def load_config(self):
return AutoConfig.from_pretrained(
pretrained_model_name_or_path=self.trans_config.pretrained_model_name_or_path,
trust_remote_code=True,
)
return self.load_cfg()

@lru_cache(maxsize=1)
def load_tokenizer(self):
Expand Down Expand Up @@ -155,9 +149,7 @@ def main():
with cf.Distributed(rmm_pool_size=args.pool_size, n_workers=args.num_workers):
model = ModelForSeq2SeqModel(Config)
pipe = op.Sequential(
op.Tokenizer(
model, cols=[args.input_column], tokenizer_type="default", max_length=255
),
op.Tokenizer(model, cols=[args.input_column], tokenizer_type="default", max_length=255),
op.Predictor(model, sorted_data_loader=True, batch_size=args.batch_size),
repartition=args.partitions,
keep_cols=[args.input_column],
Expand Down

0 comments on commit 9929be2

Please sign in to comment.