diff --git a/crossfit/backend/torch/model.py b/crossfit/backend/torch/model.py
index 5d39b1a..767327d 100644
--- a/crossfit/backend/torch/model.py
+++ b/crossfit/backend/torch/model.py
@@ -20,6 +20,7 @@
     create_nested_list_series_from_3d_ar,
 )
 from crossfit.utils.torch_utils import cleanup_torch_cache, concat_and_pad_tensors
+from crossfit.backend.torch.loader import SortedSeqLoader
 
 
 class Model:
@@ -61,11 +62,7 @@ def max_seq_length(self) -> int:
 
     def get_model_output(self, all_outputs_ls, index, loader, pred_output_col) -> cudf.DataFrame:
         out = cudf.DataFrame(index=index)
-        _index = (
-            loader.sort_column(index.values)
-            if loader.__class__.__name__ == "SortedSeqLoader"
-            else index
-        )
+        _index = loader.sort_column(index.values) if type(loader) == SortedSeqLoader else index
 
         if self.model_output_type == "string":
             all_outputs = [o for output in all_outputs_ls for o in output]
diff --git a/examples/custom_ct2_model.py b/examples/custom_ct2_model.py
index 2e867ce..0b8116f 100644
--- a/examples/custom_ct2_model.py
+++ b/examples/custom_ct2_model.py
@@ -65,10 +65,10 @@ def __call__(self, batch):
             tokens_1d[i : i + token_ids_2d.size(1)]
             for i in range(0, len(tokens_1d), token_ids_2d.size(1))
         ]
-        tokenss = self.clean_extra_tokens(tokens_2d)
+        tokens = self.clean_extra_tokens(tokens_2d)
 
         tr_res = self.model.translate_batch(
-            tokenss,
+            tokens,
             min_decoding_length=0,
             max_decoding_length=256,
             beam_size=5,
@@ -81,10 +81,7 @@ class ModelForSeq2SeqModel(HFModel):
     def __init__(self, config):
         self.trans_config = config
-        self.config = AutoConfig.from_pretrained(
-            pretrained_model_name_or_path=self.trans_config.pretrained_model_name_or_path,
-            trust_remote_code=True,
-        )
+        self.config = self.load_cfg()
         super().__init__(
             self.trans_config.pretrained_model_name_or_path, model_output_type="string"
         )
 
@@ -94,10 +91,7 @@ def load_model(self, device="cuda"):
         return model
 
     def load_config(self):
-        return AutoConfig.from_pretrained(
-            pretrained_model_name_or_path=self.trans_config.pretrained_model_name_or_path,
-            trust_remote_code=True,
-        )
+        return self.load_cfg()
 
     @lru_cache(maxsize=1)
     def load_tokenizer(self):
@@ -155,9 +149,7 @@ def main():
     with cf.Distributed(rmm_pool_size=args.pool_size, n_workers=args.num_workers):
         model = ModelForSeq2SeqModel(Config)
         pipe = op.Sequential(
-            op.Tokenizer(
-                model, cols=[args.input_column], tokenizer_type="default", max_length=255
-            ),
+            op.Tokenizer(model, cols=[args.input_column], tokenizer_type="default", max_length=255),
             op.Predictor(model, sorted_data_loader=True, batch_size=args.batch_size),
             repartition=args.partitions,
             keep_cols=[args.input_column],