diff --git a/src/modalities/models/huggingface/huggingface_models.py b/src/modalities/models/huggingface/huggingface_models.py
index b80f222d..9a1d5b93 100644
--- a/src/modalities/models/huggingface/huggingface_models.py
+++ b/src/modalities/models/huggingface/huggingface_models.py
@@ -2,8 +2,9 @@
 from typing import Any, Dict, List, Optional
 
 import torch
-from pydantic import BaseModel
-from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer, PreTrainedTokenizer
+
+from pydantic import BaseModel, ConfigDict
+from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer
 
 from modalities.config.lookup_enum import LookupEnum
 from modalities.models.model import NNModel
@@ -36,6 +37,10 @@ class HuggingFacePretrainedModelConfig(BaseModel):
     model_args: Optional[Any] = None
     kwargs: Optional[Any] = None
 
+    # avoid warning about protected namespace 'model_', see
+    # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces
+    model_config = ConfigDict(protected_namespaces=())
+
 
 class HuggingFacePretrainedModel(NNModel):
 
diff --git a/src/modalities/tokenization/tokenizer_wrapper.py b/src/modalities/tokenization/tokenizer_wrapper.py
index 1e413a92..f419771e 100644
--- a/src/modalities/tokenization/tokenizer_wrapper.py
+++ b/src/modalities/tokenization/tokenizer_wrapper.py
@@ -19,9 +19,11 @@ def vocab_size(self) -> int:
     def get_token_id(self, token: str) -> int:
         raise NotImplementedError
+
 
     def get_token_id(self, token: str) -> int:
         raise NotImplementedError
+
 
 class PreTrainedHFTokenizer(TokenizerWrapper):
     def __init__(
         self, pretrained_model_name_or_path: str, max_length: int, truncation: bool = True, padding: str = "max_length"
@@ -44,10 +46,12 @@ def tokenize(self, text: str) -> List[int]:
         )["input_ids"]
         return tokens
+
 
     def decode(self, token_ids: List[int]) -> str:
         decoded_text = self.tokenizer.decode(token_ids)
         return decoded_text
+
 
     def get_token_id(self, token: str) -> int:
         token_id = self.tokenizer.convert_tokens_to_ids(token)
         if isinstance(token_id, list):