diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index d8cd0fe3e9fe..407f5d4dd13e 100644 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -20,6 +20,7 @@ import json import logging import os +import sys from typing import Dict, Optional, Tuple from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url @@ -98,15 +99,19 @@ def __init__(self, **kwargs): logger.error("Can't set {} with value {} for {}".format(key, value, self)) raise err - def save_pretrained(self, save_directory): - """ - Save a configuration object to the directory `save_directory`, so that it - can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method. + @property + def num_labels(self): + return self._num_labels - Args: - save_directory (:obj:`string`): - Directory where the configuration JSON file will be saved. - """ + @num_labels.setter + def num_labels(self, num_labels): + self._num_labels = num_labels + self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)} + self.id2label = dict((int(key), value) for key, value in self.id2label.items()) + self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) + self.label2id = dict((key, int(value)) for key, value in self.label2id.items()) + + def save_pretrained(self, save_directory): assert os.path.isdir( save_directory ), "Saving path should be a directory where the model and configuration can be saved" @@ -223,8 +228,6 @@ def get_config_dict( local_files_only=local_files_only, ) # Load config dict - if resolved_config_file is None: - raise EnvironmentError config_dict = cls._dict_from_json_file(resolved_config_file) except EnvironmentError: