@@ -164,6 +164,7 @@ def SingleOutputTransformerEncoderLayer(
     dtype=None,
     sequence_len: int = None,
     single_output_forward=False,
+    query_index: int = -1,
 ):
     """Continual Single-output Transformer Encoder layer.
 
@@ -191,6 +192,7 @@ def SingleOutputTransformerEncoderLayer(
         dtype: datatype of layer parameters. Defaults to None.
         sequence_len: length of token-sequence to perform attention across. Defaults to None.
         single_output_forward: whether to restrict the attention to the last token during forward. Defaults to False.
+        query_index: the sequence position index to compute the attention for. Defaults to -1.
 
     Examples::
 
@@ -225,7 +227,7 @@ def SingleOutputTransformerEncoderLayer(
         bias=True,
         batch_first=True,
         embed_dim_second=True,
-        query_index=-1,
+        query_index=query_index,
         device=device,
         dtype=dtype,
         sequence_len=sequence_len,
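
Note on the hunk above: query_index was previously hard-coded to -1 (the last token) in the inner attention call and is now exposed as a parameter. A minimal usage sketch, assuming the usual "import continual as co" alias and that SingleOutputTransformerEncoderLayer takes the same d_model/nhead/sequence_len arguments and (batch, embedding, sequence) input layout as the factory examples further down; the full signature is not shown in this hunk:

    import torch
    import continual as co

    # Sketch only: d_model, nhead, and the input layout are assumed from the
    # factory examples below. query_index=0 computes attention for the first
    # token instead of the previously hard-coded last token (-1).
    layer = co.SingleOutputTransformerEncoderLayer(
        d_model=512,
        nhead=8,
        sequence_len=32,
        query_index=0,
    )
    src = torch.rand(10, 512, 32)  # (batch, embedding, sequence)
    out = layer(src)
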
@@ -462,7 +464,7 @@ def TransformerEncoderLayerFactory(
 
     Examples::
 
-        encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8)
+        encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8, sequence_len=32)
         transformer_encoder = co.TransformerEncoder(encoder_layer, num_layers=2)
         src = torch.rand(10, 512, 32)
         out = transformer_encoder(src)
@@ -527,7 +529,7 @@ class TransformerEncoder(Sequential):
 
     Examples::
 
-        encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8)
+        encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8, sequence_len=32)
         transformer_encoder = co.TransformerEncoder(encoder_layer, num_layers=2)
         src = torch.rand(10, 512, 32)
         out = transformer_encoder(src)
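
For reference, the updated docstring example runs end to end roughly as follows; a sketch assuming the library's usual "import continual as co" alias, with sequence_len=32 matching the sequence dimension of src:

    import torch
    import continual as co

    # sequence_len is now given explicitly; in a continual layer it presumably
    # sizes the attention state carried between steps, and 32 must match the
    # last (sequence) dimension of src.
    encoder_layer = co.TransformerEncoderLayerFactory(d_model=512, nhead=8, sequence_len=32)
    transformer_encoder = co.TransformerEncoder(encoder_layer, num_layers=2)
    src = torch.rand(10, 512, 32)  # (batch, embedding, sequence)
    out = transformer_encoder(src)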