
Commit bc49336

Author: Yusuke Oda
Merge pull request #195 from neulab/e501

Complying with E501

Former-commit-id: ee63b46
2 parents: 84792e3 + c444740

28 files changed: 1383 additions, 294 deletions
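Most of the hunks in this commit follow the same pattern: a string, comment, or docstring line that exceeded the flake8 E501 line-length limit is wrapped across several shorter lines. For strings, the wrapping relies on Python joining adjacent string literals at compile time, so the resulting message is unchanged. A minimal sketch of the pattern (the variable and value below are illustrative, not taken from the repository):

num_outputs = 3  # illustrative value

# One long literal (flagged by flake8 E501 when over the configured limit) ...
msg_long = f"ExplainaBoard currently only supports 1 or 2 system outputs, but received {num_outputs}"  # noqa: E501

# ... is equivalent to adjacent literals wrapped in parentheses.
msg_wrapped = (
    f"ExplainaBoard currently only supports 1 or 2 system outputs, but "
    f"received {num_outputs}"
)
assert msg_long == msg_wrapped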

explainaboard/explainaboard_main.py

Lines changed: 10 additions & 4 deletions
@@ -20,7 +20,10 @@ def main():
         type=str,
         required=True,
         nargs="+",
-        help="the directories of system outputs. Multiple one should be separated by space, for example: system1 system2",
+        help=(
+            "the directories of system outputs. Multiple one should be separated by "
+            "space, for example: system1 system2"
+        ),
     )
 
     parser.add_argument(
@@ -84,11 +87,13 @@ def main():
     # Checks on inputs
     if num_outputs > 2:
         raise ValueError(
-            f'ExplainaBoard currently only supports 1 or 2 system outputs, but received {num_outputs}'
+            f'ExplainaBoard currently only supports 1 or 2 system outputs, but '
+            f'received {num_outputs}'
         )
     if task not in TaskType.list():
         raise ValueError(
-            f'Task name {task} was not recognized. ExplainaBoard currently supports: {TaskType.list()}'
+            f'Task name {task} was not recognized. ExplainaBoard currently supports: '
+            f'{TaskType.list()}'
         )
 
     # Read in data and check validity
@@ -101,7 +106,8 @@ def main():
         num0 = len(system_datasets[0])
         num1 = len(system_datasets[1])
         raise ValueError(
-            f'Data must be identical for pairwise analysis, but length of files {system_datasets[0]} ({num0}) != {system_datasets[1]} ({num1})'
+            f'Data must be identical for pairwise analysis, but length of files '
+            f'{system_datasets[0]} ({num0}) != {system_datasets[1]} ({num1})'
         )
     if (
         loaders[0].user_defined_features_configs

explainaboard/feature.py

Lines changed: 43 additions & 26 deletions
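The docstrings rewrapped in the hunks below for _cast_to_python_objects and cast_to_python_objects describe a short-circuit: only the first non-None element of a list is inspected, and the whole list is converted only if that element needs casting. A standalone sketch of that idea, illustrative only and not the feature.py implementation:

def cast_list(values, needs_cast, cast):
    # Inspect only the first non-None element; if it does not need casting,
    # assume the rest do not either and return the input unchanged.
    first = next((v for v in values if v is not None), None)
    if first is None or not needs_cast(first):
        return values, False
    return [cast(v) if v is not None else None for v in values], True

print(cast_list(["1", None, "3"], lambda v: isinstance(v, str), int))  # ([1, None, 3], True)
print(cast_list([1, 2, 3], lambda v: isinstance(v, str), int))         # ([1, 2, 3], False)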
@@ -18,8 +18,8 @@
 
 def _arrow_to_datasets_dtype(arrow_type: pa.DataType) -> str:
     """
-    _arrow_to_datasets_dtype takes a pyarrow.DataType and converts it to a datasets string dtype.
-    In effect, `dt == string_to_arrow(_arrow_to_datasets_dtype(dt))`
+    _arrow_to_datasets_dtype takes a pyarrow.DataType and converts it to a datasets
+    string dtype. In effect, `dt == string_to_arrow(_arrow_to_datasets_dtype(dt))`
     """
 
     if pa.types.is_null(arrow_type):
@@ -74,11 +74,11 @@ def string_to_arrow(datasets_dtype: str) -> pa.DataType:
     """
     string_to_arrow takes a datasets string dtype and converts it to a pyarrow.DataType.
     In effect, `dt == string_to_arrow(_arrow_to_datasets_dtype(dt))`
-    This is necessary because the datasets.Value() primitive type is constructed using a string dtype
-    Value(dtype=str)
-    But Features.type (via `get_nested_type()` expects to resolve Features into a pyarrow Schema,
-    which means that each Value() must be able to resolve into a corresponding pyarrow.DataType, which is the
-    purpose of this function.
+    This is necessary because the datasets.Value() primitive type is constructed using a
+    string dtype Value(dtype=str)
+    But Features.type (via `get_nested_type()` expects to resolve Features into a
+    pyarrow Schema, which means that each Value() must be able to resolve into a
+    corresponding pyarrow.DataType, which is the purpose of this function.
     """
     timestamp_regex = re.compile(r"^timestamp\[(.*)\]$")
     timestamp_matches = timestamp_regex.search(datasets_dtype)
@@ -97,16 +97,21 @@ def string_to_arrow(datasets_dtype: str) -> pa.DataType:
             return pa.timestamp(internals_matches.group(1), internals_matches.group(2))
         else:
             raise ValueError(
-                f"{datasets_dtype} is not a validly formatted string representation of a pyarrow timestamp."
-                f"Examples include timestamp[us] or timestamp[us, tz=America/New_York]"
-                f"See: https://arrow.apache.org/docs/python/generated/pyarrow.timestamp.html#pyarrow.timestamp"
+                f"""
+{datasets_dtype} is not a validly formatted string representation of a pyarrow
+timestamp. Examples include timestamp[us] or timestamp[us, tz=America/New_York]
+See:
+https://arrow.apache.org/docs/python/generated/pyarrow.timestamp.html#pyarrow.timestamp
+"""
             )
     elif datasets_dtype not in pa.__dict__:
         if str(datasets_dtype + "_") not in pa.__dict__:
             raise ValueError(
-                f"Neither {datasets_dtype} nor {datasets_dtype + '_'} seems to be a pyarrow data type. "
-                f"Please make sure to use a correct data type, see: "
-                f"https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions"
+                f"""
+Neither {datasets_dtype} nor {datasets_dtype + '_'} seems to be a pyarrow data type.
+Please make sure to use a correct data type, see:
+https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions
+"""
             )
         arrow_data_factory_function_name = str(datasets_dtype + "_")
     else:
@@ -119,17 +124,23 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool) -> tuple[Any, bool]:
     """
     Cast pytorch/tensorflow/pandas objects to python numpy array/lists.
     It works recursively.
-    To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be casted.
-    If the first element needs to be casted, then all the elements of the list will be casted, otherwise they'll stay the same.
-    This trick allows to cast objects that contain tokenizers outputs without iterating over every single token for example.
+    To avoid iterating over possibly long lists, it first checks if the first element
+    that is not None has to be casted.
+    If the first element needs to be casted, then all the elements of the list will be
+    casted, otherwise they'll stay the same.
+    This trick allows to cast objects that contain tokenizers outputs without iterating
+    over every single token for example.
     Args:
         obj: the object (nested struct) to cast
-        only_1d_for_numpy (bool): whether to keep the full multi-dim tensors as multi-dim numpy arrays, or convert them to
-            nested lists of 1-dimensional numpy arrays. This can be useful to keep only 1-d arrays to instantiate Arrow arrays.
-            Indeed Arrow only support converting 1-dimensional array values.
+        only_1d_for_numpy (bool): whether to keep the full multi-dim tensors as
+            multi-dim numpy arrays, or convert them to nested lists of 1-dimensional
+            numpy arrays. This can be useful to keep only 1-d arrays to instantiate
+            Arrow arrays. Indeed Arrow only support converting 1-dimensional array
+            values.
     Returns:
         casted_obj: the casted object
-        has_changed (bool): True if the object has been changed, False if it is identical
+        has_changed (bool): True if the object has been changed, False if it is
+            identical
     """
 
     if config.TF_AVAILABLE and "tensorflow" in sys.modules:
@@ -240,9 +251,12 @@ def cast_to_python_objects(obj: Any, only_1d_for_numpy=False) -> Any:
     """
     Cast numpy/pytorch/tensorflow/pandas objects to python lists.
     It works recursively.
-    To avoid iterating over possibly long lists, it first checks if the first element that is not None has to be casted.
-    If the first element needs to be casted, then all the elements of the list will be casted, otherwise they'll stay the same.
-    This trick allows to cast objects that contain tokenizers outputs without iterating over every single token for example.
+    To avoid iterating over possibly long lists, it first checks if the first element
+    that is not None has to be casted.
+    If the first element needs to be casted, then all the elements of the list will be
+    casted, otherwise they'll stay the same.
+    This trick allows to cast objects that contain tokenizers outputs without iterating
+    over every single token for example.
     Args:
         obj: the object (nested struct) to cast
     Returns:
@@ -552,7 +566,8 @@ def encode_example(self, value):
 
 def encode_nested_example(schema, obj):
     """Encode a nested example.
-    This is used since some features (in particular ClassLabel) have some logic during encoding.
+    This is used since some features (in particular ClassLabel) have some logic during
+    encoding.
     """
     # Nested structures: we allow dict, list/tuples, sequences
     if isinstance(schema, dict):
@@ -598,10 +613,12 @@ def encode_nested_example(schema, obj):
             else None
         )
     # Object with special encoding:
-    # ClassLabel will convert from string to int, TranslationVariableLanguages does some checks
+    # ClassLabel will convert from string to int,
+    # TranslationVariableLanguages does some checks
     elif isinstance(schema, (ClassLabel, Value)):
         return schema.encode_example(obj)
-    # Other object should be directly convertible to a native Arrow type (like Translation and Translation)
+    # Other object should be directly convertible to a native Arrow type
+    # (like Translation and Translation)
     return obj
 
 
explainaboard/info.py

Lines changed: 5 additions & 9 deletions
@@ -74,9 +74,10 @@ class SysOutputInfo:
         download_link (str): the url of the system output.
         paper (Paper, optional): the published paper of the system.
         features (Features, optional): the features used to describe system output's
-        column type.
+            column type.
         is_print_case (bool): Whether or not to print out cases
-        is_print_confidence_interval (bool): Whether or not to print out confidence intervals
+        is_print_confidence_interval (bool): Whether or not to print out confidence
+            intervals
     """
 
     # set in the system_output scripts
@@ -141,8 +142,8 @@ def _dump_info(self, file):
     def from_directory(cls, sys_output_info_dir: str) -> "SysOutputInfo":
         """Create SysOutputInfo from the JSON file in `sys_output_info_dir`.
         Args:
-            sys_output_info_dir (`str`): The directory containing the metadata file. This
-                should be the root directory of a specific dataset version.
+            sys_output_info_dir (`str`): The directory containing the metadata file.
+                This should be the root directory of a specific dataset version.
         """
         logger.info("Loading Dataset info from %s", sys_output_info_dir)
         if not sys_output_info_dir:
@@ -158,11 +159,6 @@ def from_directory(cls, sys_output_info_dir: str) -> "SysOutputInfo":
             sys_output_info_dict = json.load(f)
         return cls.from_dict(sys_output_info_dict)
 
-    # @classmethod
-    # def from_dict(cls, task_name: str, sys_output_info_dict: dict) -> "SysOutputInfo":
-    #     field_names = set(f.name for f in dataclasses.fields(cls))
-    #     return cls(task_name, **{k: v for k, v in sys_output_info_dict.items() if k in field_names})
-
     @classmethod
     def from_dict(cls, sys_output_info_dict: dict) -> "SysOutputInfo":
         field_names = set(f.name for f in dataclasses.fields(cls))

explainaboard/loaders/file_loader.py

Lines changed: 19 additions & 14 deletions
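The FileLoaderField docstring rewrapped below states a precedence rule: when a custom parser is given, strip_before_parsing and dtype have no effect. A small illustrative sketch of that rule (a hypothetical helper, not the library's code):

from typing import Callable, Optional

def parse_field(raw, dtype: Optional[type], strip_before_parsing: bool,
                parser: Optional[Callable] = None):
    if parser is not None:
        return parser(raw)  # the parser wins; dtype and strip_before_parsing are ignored
    if strip_before_parsing and isinstance(raw, str):
        raw = raw.strip()
    return dtype(raw) if dtype is not None else raw

print(parse_field(" 42 ", int, True))                                  # 42
print(parse_field("a,b,c", int, True, parser=lambda s: s.split(",")))  # ['a', 'b', 'c']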
@@ -14,16 +14,19 @@
 class FileLoaderField:
     """
     Args:
-        src_name: field name in the source file. use int for tsv column indices and use str for dict keys
+        src_name: field name in the source file. use int for tsv column indices and use
+            str for dict keys
         target_name: field name expected in the loaded data
-        dtype: data type of the field in the loaded data. It is only intended for simple type conversion so
-            it only supports int, float and str. Pass in None to turn off type conversion.
-        strip_before_parsing: call strip() on strings before casting to either str, int or float. It is
-            only intended to be used with these three data types. It defaults to True for str. For all other
-            types, it defaults to False
-        parser: a custom parser for the field. When called, `data_points[idx][src_name]` is passed in as input,
-            it is expected to return the parsed result. If parser is not None, `strip_before_parsing` and dtype
-            will not have any effect
+        dtype: data type of the field in the loaded data. It is only intended for simple
+            type conversion so it only supports int, float and str. Pass in None to turn
+            off type conversion.
+        strip_before_parsing: call strip() on strings before casting to either str, int
+            or float. It is only intended to be used with these three data types.
+            It defaults to True for str. For all other types, it defaults to False
+        parser: a custom parser for the field. When called, `data_points[idx][src_name]`
+            is passed in as input, it is expected to return the parsed result.
+            If parser is not None, `strip_before_parsing` and dtype will not have any
+            effect.
     """
 
     src_name: Union[int, str]
@@ -94,7 +97,8 @@ def generate_id(self, parsed_data_point: dict, sample_idx: int):
         elif self._id_field_name:
             if self._id_field_name not in parsed_data_point:
                 raise ValueError(
-                    f"The {sample_idx} data point in system outputs file does not have field {self._id_field_name}"
+                    f"The {sample_idx} data point in system outputs file does not have "
+                    f"field {self._id_field_name}"
                 )
             parsed_data_point["id"] = str(parsed_data_point[self._id_field_name])
 
@@ -104,16 +108,17 @@ def load_raw(cls, data: str, source: Source) -> Iterable:
         fields information to parse the data points.
 
         Args:
-            data (str): base64 encoded system output content or a path for the system output file
-            source: source of data
+            data (str): base64 encoded system output content or a path for the system
+                output file
+            source: source of data
         """
         raise NotImplementedError(
             "load_raw() is not implemented for the base FileLoader"
         )
 
     def load(self, data: str, source: Source) -> Iterable[dict]:
-        """Load data from source, parse data points with fields information and return an
-        iterable of data points.
+        """Load data from source, parse data points with fields information and return
+        an iterable of data points.
         """
         raw_data = self.load_raw(data, source)
         parsed_data_points: list[dict] = []

explainaboard/loaders/kg_link_tail_prediction.py

Lines changed: 6 additions & 4 deletions
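The format line in the docstring below is terse; read literally, each entry of the system output JSON maps a tab-joined "head relation trueTail" key to a ranked list of predicted tails. An illustrative example (not taken from the repository):

system_output = {
    "Tim Berners-Lee\tinvented\tWorld Wide Web": [
        "World Wide Web", "Internet", "HTML", "HTTP", "Hypertext",
    ],
}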
@@ -13,7 +13,7 @@
 class KgLinkTailPredictionLoader(Loader):
     """
     Validate and Reformat system output file with json format:
-    "head \t relation \t trueTail": [predTail1, predTail2, predTail3, predTail4, predTail5],
+    "head \t relation \t trueTail": [predTail1, predTail2, ..., predTail5],
 
     usage:
         please refer to `test_loaders.py`
@@ -24,14 +24,16 @@ class KgLinkTailPredictionLoader(Loader):
 
     def load(self) -> Iterable[dict]:
         """
-        :param path_system_output: the path of system output file with following format:
-        "head \t relation \t trueTail": [predTail1, predTail2, predTail3, predTail4, predTail5],
+        :param path_system_output:
+            the path of system output file with following format:
+            "head \t relation \t trueTail": [predTail1, predTail2, ..., predTail5],
 
         :return: class object
         """
         data: list[dict] = []
 
-        # TODO(odashi): Avoid potential bug: load_raw returns Iterable[Any] which is not a dict.
+        # TODO(odashi):
+        #     Avoid potential bug: load_raw returns Iterable[Any] which is not a dict.
         raw_data: dict[str, dict[str, str]] = self.file_loaders[  # type: ignore
             unwrap(self._file_type)
         ].load_raw(self._data, self._source)

explainaboard/loaders/loader.py

Lines changed: 6 additions & 4 deletions
@@ -21,8 +21,8 @@ class Loader:
        data: base64 encoded system output content or a path for the system output file
        source: source of data
        file type: tsv, json, conll, etc.
-        file_loaders: a dict of file loaders. To customize the loading process, either implement
-            a custome FileLoader or override `load()`
+        file_loaders: a dict of file loaders. To customize the loading process, either
+            implement a custome FileLoader or override `load()`
     """
 
     _default_source = Source.local_filesystem
@@ -49,7 +49,8 @@ def __init__(
 
         if self._file_type not in self.file_loaders:
             raise NotImplementedError(
-                f"A file loader for {self._file_type} is not provided. please add it to the file_loaders."
+                f"A file loader for {self._file_type} is not provided. "
+                "please add it to the file_loaders."
             )
 
         self._user_defined_features_configs: dict = (
@@ -60,7 +61,8 @@ def __init__(
     def user_defined_features_configs(self) -> dict:
        if self._user_defined_features_configs is None:
            raise Exception(
-                "User defined features configs are not available (data has not been loaded))"
+                "User defined features configs are not available "
+                "(data has not been loaded))"
            )
        return self._user_defined_features_configs
 
explainaboard/loaders/qa_multiple_choice.py

Lines changed: 4 additions & 3 deletions
@@ -13,7 +13,7 @@
 class QAMultipleChoiceLoader(Loader):
     """
     Validate and Reformat system output file with json format:
-    "head \t relation \t trueTail": [predTail1, predTail2, predTail3, predTail4, predTail5],
+    "head \t relation \t trueTail": [predTail1, predTail2, ..., predTail5],
 
     usage:
         please refer to `test_loaders.py`
@@ -26,8 +26,9 @@ class QAMultipleChoiceLoader(Loader):
 
     def load(self) -> Iterable[dict]:
         """
-        :param path_system_output: the path of system output file with following format:
-        "head \t relation \t trueTail": [predTail1, predTail2, predTail3, predTail4, predTail5],
+        :param path_system_output:
+            the path of system output file with following format:
+            "head \t relation \t trueTail": [predTail1, predTail2, ..., predTail5],
 
         :return: class object
         """