Skip to content

Commit 675eae1

Browse files
afrozenatorcopybara-github
authored andcommitted
No need for a validation split, if eval_holdout_size has been specified.
PiperOrigin-RevId: 375127276
1 parent 86dc892 commit 675eae1

File tree

6 files changed

+72
-4
lines changed

6 files changed

+72
-4
lines changed

trax/data/inputs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,9 @@ def Parallel( # pylint: disable=invalid-name
153153
# Remove generators with zero counters
154154
counters = list(counters)
155155
fns = list(fns)
156-
zeros = [j for j in range(len(counters)) if counters[j] != 0]
157-
counters = [counters[j] for j in zeros]
158-
fns = [fns[j] for j in zeros]
156+
non_zeros = [j for j in range(len(counters)) if counters[j] != 0]
157+
counters = [counters[j] for j in non_zeros]
158+
fns = [fns[j] for j in non_zeros]
159159
else:
160160
counters = [1] * len(fns)
161161

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
{
2+
"citation": "@misc {paracrawl,\n title = \"ParaCrawl\",\n year = \"2018\",\n url = \"http://paracrawl.eu/download.html.\"\n}",
3+
"configDescription": "Translation dataset from English to de.",
4+
"configName": "ende",
5+
"description": "Web-Scale Parallel Corpora for Official European Languages.",
6+
"downloadSize": "1307754745",
7+
"location": {
8+
"urls": [
9+
"https://paracrawl.eu/releases.html"
10+
]
11+
},
12+
"name": "para_crawl",
13+
"splits": [
14+
{
15+
"name": "train",
16+
"numBytes": "3241",
17+
"shardLengths": [
18+
"10"
19+
]
20+
}
21+
],
22+
"supervisedKeys": {
23+
"input": "en",
24+
"output": "de"
25+
},
26+
"version": "1.2.0"
27+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"type": "tensorflow_datasets.core.features.translation_feature.Translation",
3+
"content": {
4+
"languages": [
5+
"de",
6+
"en"
7+
]
8+
}
9+
}
3.17 KB
Binary file not shown.

trax/data/tf_inputs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def _train_and_eval_dataset(dataset_name,
226226
if dataset_name != 'c4/multilingual' and tfds.Split.TRAIN not in splits:
227227
raise ValueError('To train we require a train split in the dataset.')
228228
train_split = tfds.Split.TRAIN if dataset_name != 'c4/multilingual' else 'en'
229+
eval_split = None
229230
train_examples = info.splits[train_split].num_examples
230231
eval_holdout_examples = int(train_examples * eval_holdout_size)
231232
if eval_holdout_examples > 0 or subsplit is not None:
@@ -248,7 +249,7 @@ def _train_and_eval_dataset(dataset_name,
248249
'validation_mismatched' if use_alt_eval else 'validation_matched')
249250
elif dataset_name == 'c4/multilingual':
250251
eval_split = 'en-validation'
251-
else:
252+
elif eval_split is None:
252253
if tfds.Split.VALIDATION not in splits and 'test' not in splits:
253254
raise ValueError('We require a validation or test split in the dataset.')
254255
eval_split = tfds.Split.VALIDATION

trax/data/tf_inputs_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,37 @@ def test_TFDS_single_host_with_eval_holdout(self):
143143
print(f'Eval: {d}')
144144
break
145145

146+
def test_TFDS_single_host_with_eval_holdout_no_valid_split(self):
147+
train_ds_gen = tf_inputs.TFDS(
148+
'para_crawl/ende',
149+
data_dir=_TESTDATA,
150+
train=True,
151+
host_id=0,
152+
keys=('en', 'de'),
153+
n_hosts=1,
154+
eval_holdout_size=0.1)
155+
156+
# Just ensure that this doesn't crash.
157+
for d in train_ds_gen():
158+
print(f'Train: {d}')
159+
break
160+
161+
# para_crawl doesn't have a validation set, see that this still doesn't
162+
# crash because of eval_holdout_set.
163+
valid_ds_gen = tf_inputs.TFDS(
164+
'para_crawl/ende',
165+
data_dir=_TESTDATA,
166+
train=False,
167+
host_id=0,
168+
keys=('en', 'de'),
169+
n_hosts=1,
170+
eval_holdout_size=0.1)
171+
172+
# Just ensure that this doesn't crash.
173+
for d in valid_ds_gen():
174+
print(f'Eval: {d}')
175+
break
176+
146177
def test_TFDS_mnli_split_is_eval(self):
147178
with mock.patch('tensorflow_datasets.load') as tfds_load:
148179
with mock.patch('trax.data.tf_inputs.download_and_prepare',

0 commit comments

Comments
 (0)