- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Locate a model which:

- isn't present in TOKENIZER_MAPPING
- doesn't specify model_config.tokenizer_class
- nevertheless has a tokenizer on the hub, loadable with AutoTokenizer
These requirements mean this happens only for custom models (ones not integrated into the library), AFAIK. Running those requires trust_remote_code=True, so it might be wise to create your own example meeting these requirements. I will be using "tcheda/mot_test".
Verify that the code works properly when the tokenizer and model are passed as pre-instantiated objects:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

tokenizer = AutoTokenizer.from_pretrained("tcheda/mot_test", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("tcheda/mot_test", trust_remote_code=True)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
```
Attempt to create a pipeline, specifying both the model and the tokenizer as strings:

```python
from transformers import pipeline

pipe = pipeline("text-generation", model="tcheda/mot_test", tokenizer="tcheda/mot_test", trust_remote_code=True)
```
The code crashes immediately with:
```
/usr/local/lib/python3.10/dist-packages/transformers/pipelines/__init__.py in pipeline(task, model, config, tokenizer, feature_extractor, image_processor, framework, revision, use_fast, token, device, device_map, torch_dtype, trust_remote_code, model_kwargs, pipeline_class, **kwargs)
   1106     kwargs["device"] = device
   1107
-> 1108     return pipeline_class(model=model, framework=framework, task=task, **kwargs)

/usr/local/lib/python3.10/dist-packages/transformers/pipelines/text_generation.py in __init__(self, *args, **kwargs)
     94
     95     def __init__(self, *args, **kwargs):
---> 96         super().__init__(*args, **kwargs)
     97         self.check_model_type(
     98             TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

/usr/local/lib/python3.10/dist-packages/transformers/pipelines/base.py in __init__(self, model, tokenizer, feature_extractor, image_processor, modelcard, framework, task, args_parser, device, torch_dtype, binary_output, **kwargs)
    895             self.tokenizer is not None
    896             and self.model.can_generate()
--> 897             and self.tokenizer.pad_token_id is not None
    898             and self.model.generation_config.pad_token_id is None
    899         ):

AttributeError: 'str' object has no attribute 'pad_token_id'
```
The bug is probably at transformers/src/transformers/pipelines/__init__.py, line 907 in 1c68f2c:

```python
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
```

This isn't set when tokenizer is a string, so the tokenizer initialization block (which contains proper handling of the string case) is never entered, and the string is forwarded to the pipeline constructor unchanged.
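The failing logic can be sketched locally without touching the hub. Everything below is a simplified stand-in, not the real transformers objects; only the load_tokenizer expression is reproduced verbatim from pipelines/__init__.py:

```python
# Simplified stand-ins -- NOT the real transformers classes; only the
# load_tokenizer expression below is copied from pipelines/__init__.py.
TOKENIZER_MAPPING = {}          # a custom config class is never registered here

class RemoteConfig:             # stand-in for a custom (remote-code) model config
    tokenizer_class = None      # custom models often leave this unset

model_config = RemoteConfig()
tokenizer = "tcheda/mot_test"   # tokenizer passed to pipeline() as a string

load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None

# load_tokenizer is False: the block that would resolve the string via
# AutoTokenizer.from_pretrained(tokenizer) is skipped, and the raw string
# later reaches `self.tokenizer.pad_token_id` in base.py.
print(load_tokenizer)  # False
```

Both conditions are False for a custom model, which matches the traceback above: the unresolved string is what ends up as `self.tokenizer`.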
Expected behavior

The pipeline should be created correctly, loading the tokenizer as if with AutoTokenizer.from_pretrained(tokenizer). This is the behaviour described in the docs.
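One possible direction, purely my assumption and not a confirmed or proposed patch, would be to also set load_tokenizer when the caller explicitly passed a tokenizer string. Sketched here with simplified stand-ins rather than the real transformers internals (the `isinstance` clause is hypothetical):

```python
# Hypothetical fix sketch -- an assumption, not the actual transformers patch.
TOKENIZER_MAPPING = {}         # stand-in: custom config class is not registered

class RemoteConfig:            # stand-in for a custom model's config
    tokenizer_class = None

model_config = RemoteConfig()
tokenizer = "tcheda/mot_test"  # explicitly passed by the caller

load_tokenizer = (
    type(model_config) in TOKENIZER_MAPPING
    or model_config.tokenizer_class is not None
    or isinstance(tokenizer, str)  # hypothetical extra clause for explicit strings
)
print(load_tokenizer)  # True
```

With such a clause the string would reach the existing string-handling path instead of the pipeline constructor, giving the documented behaviour.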
System Info

transformers version: 4.42.1

Who can help?

@Narsil