Skip to content

Commit 2ae2ba4

Browse files
author
Ryan Sepassi
committedNov 14, 2017
Bug fixes
PiperOrigin-RevId: 175611828
1 parent 75b75f2 commit 2ae2ba4

File tree

4 files changed

+20
-7
lines changed

4 files changed

+20
-7
lines changed
 

‎CONTRIBUTING.md

+10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
# How to Contribute
22

3+
# Issues
4+
5+
* Please tag your issue with `bug`, `feature request`, or `question` to help us
6+
effectively respond.
7+
* Please include the versions of TensorFlow and Tensor2Tensor you are running
8+
(run `pip list | grep tensor`)
9+
* Please provide the command line you ran as well as the log output.
10+
11+
# Pull Requests
12+
313
We'd love to accept your patches and contributions to this project. There are
414
just a few small guidelines you need to follow.
515

‎tensor2tensor/bin/t2t-datagen

+4-5
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ from tensor2tensor.data_generators import all_problems # pylint: disable=unused
4343
from tensor2tensor.data_generators import audio
4444
from tensor2tensor.data_generators import generator_utils
4545
from tensor2tensor.data_generators import snli
46-
from tensor2tensor.data_generators import translate
4746
from tensor2tensor.data_generators import wsj_parsing
4847
from tensor2tensor.utils import registry
4948
from tensor2tensor.utils import usr_dir
@@ -82,10 +81,10 @@ _SUPPORTED_PROBLEM_GENERATORS = {
8281
lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
8382
lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
8483
"parsing_english_ptb8k": (
85-
lambda: translate.parsing_token_generator(
86-
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),
87-
lambda: translate.parsing_token_generator(
88-
FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13)),
84+
lambda: wsj_parsing.parsing_token_generator(
85+
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13, 2**9),
86+
lambda: wsj_parsing.parsing_token_generator(
87+
FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13, 2**9)),
8988
"parsing_english_ptb16k": (
9089
lambda: wsj_parsing.parsing_token_generator(
9190
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9),

‎tensor2tensor/models/shake_shake.py

+2
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ def model_fn_body(self, features):
132132
@registry.register_hparams
133133
def shakeshake_cifar10():
134134
"""Parameters for CIFAR-10."""
135+
tf.logging.warning("shakeshake_cifar10 hparams have not been verified to "
136+
"achieve good performance.")
135137
hparams = common_hparams.basic_params1()
136138
# This leads to effective batch size 128 when number of GPUs is 1
137139
hparams.batch_size = 4096 * 8

‎tensor2tensor/utils/expert_utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,8 @@ def daisy_chain_getter(getter, name, *args, **kwargs):
200200
else:
201201
var = getter(name, *args, **kwargs)
202202
v = tf.identity(var._ref()) # pylint: disable=protected-access
203+
_add_variable_proxy_methods(var, v)
203204
# update the cache
204-
_add_variable_proxy_methods(var, v)
205205
cache[name] = v
206206
cache[device_var_key] = v
207207
return v
@@ -210,10 +210,12 @@ def daisy_chain_getter(getter, name, *args, **kwargs):
210210
# so we make a custom getter that uses identity to cache the variable.
211211
# pylint: disable=cell-var-from-loop
212212
def caching_getter(getter, name, *args, **kwargs):
213-
v = getter(name, *args, **kwargs)
213+
"""Cache variables on device."""
214214
key = (self._caching_devices[i], name)
215215
if key in cache:
216216
return cache[key]
217+
218+
v = getter(name, *args, **kwargs)
217219
with tf.device(self._caching_devices[i]):
218220
ret = tf.identity(v._ref()) # pylint: disable=protected-access
219221
_add_variable_proxy_methods(v, ret)

0 commit comments

Comments
 (0)
Please sign in to comment.