Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import json
import logging
import os
import sys
from typing import Dict, Optional, Tuple

from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
Expand Down Expand Up @@ -98,15 +99,19 @@ def __init__(self, **kwargs):
logger.error("Can't set {} with value {} for {}".format(key, value, self))
raise err

def save_pretrained(self, save_directory):
"""
Save a configuration object to the directory `save_directory`, so that it
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
@property
def num_labels(self):
return self._num_labels

Args:
save_directory (:obj:`string`):
Directory where the configuration JSON file will be saved.
"""
@num_labels.setter
def num_labels(self, num_labels):
self._num_labels = num_labels
self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)}
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())

def save_pretrained(self, save_directory):
assert os.path.isdir(
save_directory
), "Saving path should be a directory where the model and configuration can be saved"
Expand Down Expand Up @@ -223,8 +228,6 @@ def get_config_dict(
local_files_only=local_files_only,
)
# Load config dict
if resolved_config_file is None:
raise EnvironmentError
config_dict = cls._dict_from_json_file(resolved_config_file)

except EnvironmentError:
Expand Down