Commit a2497cb

Lukasz Kaiser authored and copybara-github committed
* Rename gumbel_sample to logsoftmax_sample for clarity.
* Add the forgotten decoding import in supervised/__init__.
* Allow access to decoding.autoregressive_sample in streaming mode.

PiperOrigin-RevId: 323177464
1 parent 88b033c commit a2497cb

File tree

5 files changed: +67 -32 lines changed


trax/layers/core.py

Lines changed: 4 additions & 4 deletions
@@ -378,12 +378,12 @@ def multigaussian_loss(preds, targets, ngauss=1):  # pylint: disable=invalid-name
   return fastmath.logsumexp(loglogits + glogprobs, axis=-1)
 
 
-def gumbel_sample(log_probs, temperature=1.0):  # pylint: disable=invalid-name
-  """Returns a Gumbel sample from a categorical distribution, with temperature.
+def logsoftmax_sample(log_probs, temperature=1.0):  # pylint: disable=invalid-name
+  """Returns a sample from a log-softmax output, with temperature.
 
   Args:
-    log_probs: <tbd>
-    temperature: <tbd>
+    log_probs: Logarithms of probabilities (often coming from LogSoftmax)
+    temperature: For scaling before sampling (1.0 = default, 0.0 = pick argmax)
   """
   # This is equivalent to sampling from a softmax with temperature.
   u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
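
For reference, the renamed function implements the Gumbel-max trick: adding standard Gumbel noise, scaled by the temperature, to log-probabilities and taking the argmax is equivalent to sampling from the corresponding softmax. A minimal standalone sketch under that assumption (plain NumPy; the helper name is hypothetical, and only the np.random.uniform line is visible in the hunk above):

import numpy as np

def logsoftmax_sample_sketch(log_probs, temperature=1.0):
  """Samples ids from softmax(log_probs / temperature) via the Gumbel-max trick."""
  u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
  g = -np.log(-np.log(u))  # Standard Gumbel(0, 1) noise.
  # argmax(log_probs + temperature * g) samples from softmax(log_probs / temperature);
  # temperature=0.0 degenerates to a plain argmax.
  return np.argmax(log_probs + g * temperature, axis=-1)

log_probs = np.log(np.array([[0.1, 0.2, 0.7]]))
print(logsoftmax_sample_sketch(log_probs, temperature=0.0))  # -> [2]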

trax/rl/distributions.py

Lines changed: 3 additions & 3 deletions
@@ -93,12 +93,12 @@ def _unflatten_inputs(self, inputs):
     )
 
   def sample(self, inputs, temperature=1.0):
-    # No need for LogSoftmax with Gumbel sampling - softmax normalization is
-    # subtracting a constant from every logit, and Gumbel sampling is taking
+    # No need for LogSoftmax with sampling - softmax normalization is
+    # subtracting a constant from every logit, and sampling is taking
     # a max over logits plus noise, so invariant to adding a constant.
     if temperature == 0.0:
       return jnp.argmax(self._unflatten_inputs(inputs), axis=-1)
-    return tl.gumbel_sample(self._unflatten_inputs(inputs), temperature)
+    return tl.logsoftmax_sample(self._unflatten_inputs(inputs), temperature)
 
   def log_prob(self, inputs, point):
     inputs = tl.LogSoftmax()(self._unflatten_inputs(inputs))
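
The rewritten comment leans on shift invariance: softmax normalization only subtracts a constant (the log of the partition function) from every logit, and an argmax over logits plus noise is unchanged by adding a constant, so LogSoftmax can be skipped before sampling. A quick empirical check of that claim, assuming NumPy (names here are illustrative, not from the codebase):

import numpy as np

rng = np.random.default_rng(0)

def noisy_argmax(logits):
  u = rng.uniform(1e-6, 1.0 - 1e-6, size=logits.shape)
  return np.argmax(logits - np.log(-np.log(u)))  # logits + Gumbel noise

logits = np.array([1.0, 2.0, 3.0])
n = 100_000
raw = np.bincount([noisy_argmax(logits) for _ in range(n)], minlength=3) / n
shifted = np.bincount([noisy_argmax(logits - 42.0) for _ in range(n)], minlength=3) / n
# Both empirical distributions match softmax([1, 2, 3]) ~= [0.09, 0.245, 0.665].
print(raw, shifted)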

trax/supervised/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 
 """Supervised learning imports in Trax."""
 
+from trax.supervised import decoding
 from trax.supervised import lr_schedules
 from trax.supervised import trainer_lib
 from trax.supervised import training
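
With the forgotten import added, decoding is reachable as an attribute of the package; a short sketch of what this fixes:

from trax import supervised

# Before this commit, the attribute lookup below raised AttributeError unless
# trax.supervised.decoding had already been imported explicitly elsewhere.
sampler = supervised.decoding.autoregressive_sample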

trax/supervised/decoding.py

Lines changed: 55 additions & 21 deletions
@@ -19,7 +19,51 @@
 from trax import layers as tl
 
 
-def autoregressive_sample(model, prefix=None, inputs=None,
+def autoregressive_sample_stream(model, inputs=None,
+                                 batch_size=1, temperature=1.0,
+                                 start_id=0, accelerate=True):
+  """Stream autoregressive samples from the provided model.
+
+  Note that the provided model should be an autoregressive model initialized
+  in 'predict' mode. In this mode, a model takes the outputs it is generating
+  one-by-one (instead of taking them all at once, as, e.g., during training).
+  Model state is used to store the intermediate information needed, and usually
+  the model performs inference in this mode faster than in 'eval' mode.
+
+  Args:
+    model: instance of trax.Layer, the model to sample from (at mode='predict')
+    inputs: optional tensor [batch_size, M]: inputs to provide to the model;
+      for language models (with n_in=1) we use inputs as prefix to the model
+    batch_size: how many batches to sample (default: 1)
+    temperature: sampling temperature (default: 1.0)
+    start_id: int, id for the start symbol fed at the beginning (default: 0)
+    accelerate: whether to accelerate the model before decoding (default: True)
+
+  Yields:
+    Tensor of ints of shape [batch_size] containing subsequent
+    autoregressive samples from the model.
+  """
+  if inputs is not None and inputs.shape[0] != batch_size:
+    raise ValueError(f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
+  fast_model = tl.Accelerate(model) if accelerate else model
+  cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
+  if inputs is not None and model.n_in == 1:  # use inputs as prefix
+    cur_symbol = np.concatenate([cur_symbol, inputs], axis=1)
+  while True:
+    model_input = cur_symbol
+    if inputs is not None and model.n_in > 1:
+      model_input = (inputs, cur_symbol)
+    logits = fast_model(model_input)
+    if inputs is not None and model.n_in > 1:
+      logits = logits[0]  # Pick first element from model output (a pair here)
+    sample = tl.logsoftmax_sample(logits[:, -1, :], temperature=temperature)
+    yield sample
+    # Note: we're using 'predict' mode autoregressive models here, so history
+    # is cached in the model state and we are only feeding one symbol next.
+    cur_symbol = sample[:, None]
+
+
+def autoregressive_sample(model, inputs=None,
                           batch_size=1, temperature=1.0,
                           start_id=0, eos_id=1, max_length=100,
                           accelerate=True):
@@ -33,8 +77,8 @@ def autoregressive_sample(model, prefix=None, inputs=None,
 
   Args:
     model: instance of trax.Layer, the model to sample from (at mode='predict')
-    prefix: optional tensor [batch_size, L]: prefix for decoding
-    inputs: optional tensor [batch_size, M]: inputs to provide to the model
+    inputs: optional tensor [batch_size, M]: inputs to provide to the model;
+      for language models (with n_in=1) we use inputs as prefix to the model
     batch_size: how many batches to sample (default: 1)
     temperature: sampling temperature (default: 1.0)
     start_id: int, id for the start symbol fed at the beginning (default: 0)
@@ -46,32 +90,22 @@
     a tensor of ints of shape [batch_size, N] with N <= max_length containing
     the autoregressively sampled output from the model
   """
-  if prefix is not None and prefix.shape[0] != batch_size:
-    raise ValueError(f'Prefix batch size {prefix.shape[0]} != {batch_size}.')
-  if inputs is not None and inputs.shape[0] != batch_size:
-    raise ValueError(f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
-  fast_model = tl.Accelerate(model) if accelerate else model
-  cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
-  if prefix is not None:
-    cur_symbol = np.concatenate([cur_symbol, prefix], axis=1)
   result = []
   eos_seen = []
-  for _ in range(max_length):
-    model_input = cur_symbol if inputs is None else (inputs, cur_symbol)
-    logits = fast_model(model_input)
-    if inputs is not None:
-      logits = logits[0]  # Pick first element from model output (a pair here)
-    sample = tl.gumbel_sample(logits[:, -1, :], temperature=temperature)
+  counter = 0
+  for sample in autoregressive_sample_stream(
+      model, inputs, batch_size=batch_size, temperature=temperature,
+      start_id=start_id, accelerate=accelerate):
     sample = sample[:, None]
     result.append(sample)
-    # Note: we're using 'predict' mode autoregressive models here, so history
-    # is caches in the model state and we are only feeding one symbol next.
-    cur_symbol = sample
+    counter += 1
+    if counter >= max_length:
+      return np.concatenate(result, axis=1)
     # Check at which batch positions have we already encountered EOS.
     for j in range(batch_size):
       if int(sample[j, 0]) == eos_id:
         eos_seen.append(j)
     # If EOS has been seen on all positions, stop.
     if all([j in eos_seen for j in range(batch_size)]):
-      break
+      return np.concatenate(result, axis=1)
   return np.concatenate(result, axis=1)
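
A usage sketch for the two entry points, assuming a toy TransformerLM in 'predict' mode (hyperparameters mirror the ones in decoding_test.py; this snippet is illustrative, not part of the commit):

import numpy as np
from trax import models, shapes
from trax.supervised import decoding

model = models.TransformerLM(vocab_size=10, d_model=32, d_ff=64,
                             n_layers=1, n_heads=2, mode='predict')
model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))

# Bounded decoding: returns ints of shape [batch_size, N] with N <= max_length.
s = decoding.autoregressive_sample(model, batch_size=1, max_length=5)

# Streaming decoding: re-init to reset the 'predict'-mode cache, then pull
# one [batch_size] sample per step and stop whenever the caller decides.
model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
stream = decoding.autoregressive_sample_stream(model, batch_size=1)
tokens = [next(stream) for _ in range(5)]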

trax/supervised/decoding_test.py

Lines changed: 4 additions & 4 deletions
@@ -53,8 +53,8 @@ def test_autoregressive_sample_transformerlm(self):
     self.assertLess(s2.shape[1], 11)
     model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
     prefix = np.array([[1, 2, 3]])
-    s3 = decoding.autoregressive_sample(model, eos_id=-1, max_length=10,
-                                        batch_size=1, prefix=prefix)
+    s3 = decoding.autoregressive_sample(model, prefix, eos_id=-1, max_length=10,
+                                        batch_size=1)
     self.assertEqual(s3.shape[0], 1)
     self.assertEqual(s3.shape[1], 10)
 
@@ -131,7 +131,7 @@ def test_autoregressive_sample_transformerlm_quality(self):
     pred_model.init_from_file(model_path, weights_only=True,
                               input_signature=(shape11, shape11))
     inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
-    s = decoding.autoregressive_sample(pred_model, prefix=inputs,
+    s = decoding.autoregressive_sample(pred_model, inputs,
                                        max_length=6, temperature=0.0)
     self.assertEqual(str(s[0]), '[3 7 5 3 2 4]')
 
@@ -146,7 +146,7 @@ def test_autoregressive_sample_reformerlm_quality(self):
     pred_model.init_from_file(model_path, weights_only=True,
                               input_signature=(shape11, shape11))
     inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
-    s = decoding.autoregressive_sample(pred_model, prefix=inputs,
+    s = decoding.autoregressive_sample(pred_model, inputs,
                                        max_length=6, temperature=0.0)
     self.assertEqual(str(s[0]), '[3 7 5 3 2 4]')