@@ -46,6 +46,9 @@ def __init__(
46
46
num_heads ,
47
47
max_seq_len ,
48
48
num_buckets ,
49
+ sparsity ,
50
+ target_size ,
51
+ sort_by_length ,
49
52
requires_grad ,
50
53
persistent_kernel : bool = False ,
51
54
) -> None :
@@ -54,6 +57,9 @@ def __init__(
54
57
self .num_heads = num_heads
55
58
self .max_seq_len = max_seq_len
56
59
self .num_buckets = num_buckets
60
+ self .sparsity = sparsity
61
+ self .target_size = target_size
62
+ self .sort_by_length = sort_by_length
57
63
self .all_ts_weights = torch .nn .Parameter (
58
64
torch .randn (
59
65
(self .num_buckets + 1 ,),
@@ -73,7 +79,11 @@ def __init__(
73
79
self .persistent_kernel = persistent_kernel
74
80
75
81
def forward (
76
- self , qkv : torch .Tensor , seq_offsets : torch .Tensor , timestamps : torch .Tensor
82
+ self ,
83
+ qkv : torch .Tensor ,
84
+ seq_offsets : torch .Tensor ,
85
+ timestamps : torch .Tensor ,
86
+ num_targets : torch .Tensor ,
77
87
) -> torch .Tensor :
78
88
NUM_BUCKETS = self .num_buckets
79
89
torch ._check (timestamps .size (0 ) + 1 == seq_offsets .size (0 ))
@@ -99,7 +109,7 @@ def forward(
99
109
"PW" : self .all_pos_weights ,
100
110
"Bias" : None ,
101
111
"seq2_offsets" : None ,
102
- "num_targets" : None ,
112
+ "num_targets" : num_targets ,
103
113
"Scale" : None ,
104
114
"Out" : out ,
105
115
"stride_qm" : q .stride (0 ),
@@ -171,25 +181,75 @@ def forward(
171
181
kwargs ["ATTN_BIAS_TYPE" ], # relative_bias_type
172
182
kwargs ["MAX_ATTN_LEN" ], # max_attn_len
173
183
kwargs ["CONTEXTUAL_SEQ_LEN" ], # contextual_seq_len
174
- kwargs [ "sort_by_length_indices" ], # sort_by_length
184
+ self . sort_by_length ,
175
185
)
176
186
177
187
return out
178
188
179
189
190
def generate_sparse_seq_len(
    size: int,
    max_seq_len: int,
    sparsity: float,
    device: torch.device,
) -> torch.Tensor:
    """Generate random sequence lengths averaging ~``sparsity * max_seq_len``.

    Args:
        size: Number of sequences to generate (batch size).
        max_seq_len: Upper bound for any generated length.
        sparsity: Target density in ``[0.0, 1.0]``. ``0.0`` yields all-zero
            lengths; ``1.0`` yields all lengths equal to ``max_seq_len``.
        device: Device on which to allocate the result.

    Returns:
        An ``torch.int`` tensor of shape ``(size,)`` with values in
        ``[0, max_seq_len]``.
    """
    if sparsity == 0.0:
        return torch.zeros(size=(size,), device=device, dtype=torch.int)
    if sparsity == 1.0:
        return torch.ones(size=(size,), device=device, dtype=torch.int) * max_seq_len
    if sparsity >= 0.5:
        # Sample in [(2*sparsity - 1) * max, max) so the expected length is
        # roughly sparsity * max_seq_len.
        low = int((2 * sparsity - 1.0) * max_seq_len)
        high = max_seq_len
    else:
        # Sample in [0, 2 * sparsity * max) for the same expected value.
        # Use a local instead of rebinding the max_seq_len parameter.
        low = 0
        high = int(2 * sparsity * max_seq_len)
    if high <= low:
        # torch.randint requires high > low; a degenerate range (e.g. a very
        # small sparsity * max_seq_len) collapses to the constant `low`.
        return torch.full((size,), low, device=device, dtype=torch.int)
    return torch.randint(
        low=low,
        high=high,
        size=(size,),
        device=device,
        dtype=torch.int,
    )
220
+
180
221
def get_test_inputs (
181
- batch_size , num_heads , max_seq_len , requires_grad
182
- ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
222
+ batch_size ,
223
+ num_heads ,
224
+ max_seq_len ,
225
+ sparsity ,
226
+ target_size ,
227
+ sort_by_length ,
228
+ requires_grad ,
229
+ ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
183
230
timestamp_deltas : torch .Tensor = torch .randint (
184
231
86400 ,
185
232
size = (batch_size , max_seq_len + 1 ),
186
233
).cuda ()
187
234
timestamps = timestamp_deltas .cumsum (dim = 1 )
188
235
189
- lengths = torch .randint (
190
- max_seq_len + 1 ,
191
- size = (batch_size ,),
192
- ).cuda ()
236
+ lengths = generate_sparse_seq_len (
237
+ size = batch_size ,
238
+ max_seq_len = max_seq_len ,
239
+ sparsity = sparsity ,
240
+ device = torch .device ("cuda" ),
241
+ )
242
+ # assume has_delta_q is False
243
+ num_targets = None
244
+ if target_size != 0 :
245
+ num_targets = torch .randint (
246
+ 1 ,
247
+ target_size + 1 ,
248
+ (batch_size ,),
249
+ device = lengths .device ,
250
+ dtype = lengths .dtype ,
251
+ )
252
+ num_targets = torch .where (num_targets > lengths , lengths , num_targets )
193
253
seq_offsets = torch .zeros (
194
254
(batch_size + 1 ,),
195
255
dtype = torch .int64 ,
@@ -208,4 +268,4 @@ def get_test_inputs(
208
268
.requires_grad_ (requires_grad )
209
269
.cuda ()
210
270
)
211
- return qkv , seq_offsets , timestamps
271
+ return qkv , seq_offsets , timestamps , num_targets
0 commit comments