@@ -19,7 +19,51 @@
 from trax import layers as tl
 
 
-def autoregressive_sample(model, prefix=None, inputs=None,
+def autoregressive_sample_stream(model, inputs=None,
+                                 batch_size=1, temperature=1.0,
+                                 start_id=0, accelerate=True):
+  """Stream autoregressive samples from the provided model.
+
+  Note that the provided model should be an autoregressive model initialized
+  in 'predict' mode. In this mode, a model takes the outputs it is generating
+  one-by-one (instead of taking them all at once, as, e.g., during training).
+  Model state is used to store the intermediate information needed, and usually
+  the model performs inference in this mode faster than in 'eval' mode.
+
+  Args:
+    model: instance of trax.Layer, the model to sample from (at mode='predict')
+    inputs: optional tensor [batch_size, M]: inputs to provide to the model;
+      for language models (with n_in=1) we use inputs as prefix to the model
+    batch_size: how many batches to sample (default: 1)
+    temperature: sampling temperature (default: 1.0)
+    start_id: int, id for the start symbol fed at the beginning (default: 0)
+    accelerate: whether to accelerate the model before decoding (default: True)
+
+  Yields:
+    Tensor of ints of shape [batch_size] containing subsequent
+    autoregressive samples from the model.
+  """
+  if inputs is not None and inputs.shape[0] != batch_size:
+    raise ValueError(f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
+  fast_model = tl.Accelerate(model) if accelerate else model
+  cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
+  if inputs is not None and model.n_in == 1:  # use inputs as prefix
+    cur_symbol = np.concatenate([cur_symbol, inputs], axis=1)
+  while True:
+    model_input = cur_symbol
+    if inputs is not None and model.n_in > 1:
+      model_input = (inputs, cur_symbol)
+    logits = fast_model(model_input)
+    if inputs is not None and model.n_in > 1:
+      logits = logits[0]  # Pick first element from model output (a pair here).
+    sample = tl.logsoftmax_sample(logits[:, -1, :], temperature=temperature)
+    yield sample
+    # Note: we're using 'predict' mode autoregressive models here, so history
+    # is cached in the model state and we only need to feed one symbol next.
+    cur_symbol = sample[:, None]
+
+
+def autoregressive_sample(model, inputs=None,
                           batch_size=1, temperature=1.0,
                           start_id=0, eos_id=1, max_length=100,
                           accelerate=True):
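
For orientation (an editor's sketch, not part of the commit): the new generator can be consumed with plain `next()` calls. The toy `TransformerLM` below uses made-up hyperparameters and freshly initialized random weights, just enough to run:

import numpy as np
from trax import models, shapes
from trax.supervised import decoding

# A tiny language model in 'predict' mode; random weights, illustration only.
model = models.TransformerLM(vocab_size=32, d_model=32, d_ff=64,
                             n_layers=1, n_heads=2, mode='predict')
model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))

# Pull one sampled token id per step from the stream.
stream = decoding.autoregressive_sample_stream(model, temperature=1.0)
first_five = [int(next(stream)[0]) for _ in range(5)]
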
@@ -33,8 +77,8 @@ def autoregressive_sample(model, prefix=None, inputs=None,
 
   Args:
     model: instance of trax.Layer, the model to sample from (at mode='predict')
-    prefix: optional tensor [batch_size, L]: prefix for decoding
-    inputs: optional tensor [batch_size, M]: inputs to provide to the model
+    inputs: optional tensor [batch_size, M]: inputs to provide to the model;
+      for language models (with n_in=1) we use inputs as prefix to the model
     batch_size: how many batches to sample (default: 1)
     temperature: sampling temperature (default: 1.0)
     start_id: int, id for the start symbol fed at the beginning (default: 0)
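
A side note on the `temperature` argument (a conceptual NumPy sketch, not trax's code): `tl.logsoftmax_sample` draws from the model's softmax distribution via the Gumbel-max trick, roughly:

import numpy as np

def gumbel_max_sample(logits, temperature=1.0):
  # argmax(logits + temperature * gumbel_noise) samples from
  # softmax(logits / temperature); temperature 0 reduces to greedy argmax.
  u = np.random.uniform(low=1e-9, high=1.0 - 1e-9, size=logits.shape)
  gumbel_noise = -np.log(-np.log(u))
  return np.argmax(logits + temperature * gumbel_noise, axis=-1)
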
@@ -46,32 +90,22 @@ def autoregressive_sample(model, prefix=None, inputs=None,
     a tensor of ints of shape [batch_size, N] with N <= max_length containing
     the autoregressively sampled output from the model
   """
-  if prefix is not None and prefix.shape[0] != batch_size:
-    raise ValueError(f'Prefix batch size {prefix.shape[0]} != {batch_size}.')
-  if inputs is not None and inputs.shape[0] != batch_size:
-    raise ValueError(f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
-  fast_model = tl.Accelerate(model) if accelerate else model
-  cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
-  if prefix is not None:
-    cur_symbol = np.concatenate([cur_symbol, prefix], axis=1)
   result = []
   eos_seen = []
-  for _ in range(max_length):
-    model_input = cur_symbol if inputs is None else (inputs, cur_symbol)
-    logits = fast_model(model_input)
-    if inputs is not None:
-      logits = logits[0]  # Pick first element from model output (a pair here).
-    sample = tl.gumbel_sample(logits[:, -1, :], temperature=temperature)
+  counter = 0
+  for sample in autoregressive_sample_stream(
+      model, inputs, batch_size=batch_size, temperature=temperature,
+      start_id=start_id, accelerate=accelerate):
     sample = sample[:, None]
     result.append(sample)
-    # Note: we're using 'predict' mode autoregressive models here, so history
-    # is cached in the model state and we only need to feed one symbol next.
-    cur_symbol = sample
+    counter += 1
+    if counter >= max_length:
+      return np.concatenate(result, axis=1)
     # Check at which batch positions we have already encountered EOS.
     for j in range(batch_size):
       if int(sample[j, 0]) == eos_id:
         eos_seen.append(j)
     # If EOS has been seen on all positions, stop.
     if all([j in eos_seen for j in range(batch_size)]):
-      break
+      return np.concatenate(result, axis=1)
   return np.concatenate(result, axis=1)
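
Finally, a hedged sketch of the refactored top-level call, showing `inputs` doubling as a decoding prefix for a language model (the prefix ids are hypothetical, and the toy model from the sketch above is re-created so its 'predict'-mode cache starts fresh):

model = models.TransformerLM(vocab_size=32, d_model=32, d_ff=64,
                             n_layers=1, n_heads=2, mode='predict')
model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))

prefix = np.array([[3, 7, 4]], dtype=np.int32)  # hypothetical token ids
out = decoding.autoregressive_sample(model, inputs=prefix,
                                     batch_size=1, eos_id=1, max_length=10)
# out has shape (1, N) with N <= 10; decoding stops early once every
# batch position has emitted eos_id.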