Skip to content

Commit

Permalink
Merge pull request #109 from Modalities/pydantic-warnings-cli
Browse files Browse the repository at this point in the history
Suppress pydantic warnings in CLI help message
  • Loading branch information
mali-git committed Jun 11, 2024
2 parents b278b3b + bc7d9f3 commit b7a714e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/modalities/models/huggingface/huggingface_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from typing import Any, Dict, List, Optional

import torch
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer, PreTrainedTokenizer

from pydantic import BaseModel, ConfigDict
from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer

from modalities.config.lookup_enum import LookupEnum
from modalities.models.model import NNModel
Expand Down Expand Up @@ -36,6 +37,10 @@ class HuggingFacePretrainedModelConfig(BaseModel):
model_args: Optional[Any] = None
kwargs: Optional[Any] = None

# avoid warning about protected namespace 'model_', see
# https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces
model_config = ConfigDict(protected_namespaces=())


class HuggingFacePretrainedModel(NNModel):

Expand Down
4 changes: 4 additions & 0 deletions src/modalities/tokenization/tokenizer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ def vocab_size(self) -> int:
def get_token_id(self, token: str) -> int:
raise NotImplementedError


def get_token_id(self, token: str) -> int:
raise NotImplementedError


class PreTrainedHFTokenizer(TokenizerWrapper):
def __init__(
self, pretrained_model_name_or_path: str, max_length: int, truncation: bool = True, padding: str = "max_length"
Expand All @@ -44,10 +46,12 @@ def tokenize(self, text: str) -> List[int]:
)["input_ids"]
return tokens


def decode(self, token_ids: List[int]) -> str:
decoded_text = self.tokenizer.decode(token_ids)
return decoded_text


def get_token_id(self, token: str) -> int:
token_id = self.tokenizer.convert_tokens_to_ids(token)
if isinstance(token_id, list):
Expand Down

0 comments on commit b7a714e

Please sign in to comment.