@@ -19,6 +19,7 @@ def arrangement(
1919 attn_mask ,
2020 scale ,
2121 output ,
22+ with_attn_mask ,
2223 with_kv_cache ,
2324 BLOCK_SIZE_M = BLOCK_SIZE_M ,
2425 BLOCK_SIZE_N = BLOCK_SIZE_N ,
@@ -68,6 +69,7 @@ def arrange_attn_mask(input):
6869 attn_mask_arranged = arrange_attn_mask (attn_mask )
6970 scale_arranged = scale
7071 output_arranged = arrange_query_or_output (output )
72+ with_attn_mask_arranged = with_attn_mask
7173
7274 if with_kv_cache :
7375 return (
@@ -81,6 +83,7 @@ def arrange_attn_mask(input):
8183 attn_mask_arranged ,
8284 scale_arranged ,
8385 output_arranged ,
86+ with_attn_mask_arranged ,
8487 )
8588
8689 return (
@@ -90,6 +93,7 @@ def arrange_attn_mask(input):
9093 attn_mask_arranged ,
9194 scale_arranged ,
9295 output_arranged ,
96+ with_attn_mask_arranged ,
9397 )
9498
9599
@@ -104,14 +108,19 @@ def application_with_kv_cache(
104108 attn_mask ,
105109 scale ,
106110 output ,
111+ with_attn_mask ,
107112):
108113 present_key_slot = present_key # noqa: F841
109114 present_value_slot = present_value # noqa: F841
110115
111- application_without_kv_cache (query , key , value , attn_mask , scale , output )
116+ application_without_kv_cache (
117+ query , key , value , attn_mask , scale , output , with_attn_mask
118+ )
112119
113120
114- def application_without_kv_cache (query , key , value , attn_mask , scale , output ):
121+ def application_without_kv_cache (
122+ query , key , value , attn_mask , scale , output , with_attn_mask
123+ ):
115124 for i in range (query .shape [0 ]):
116125 query_i = (1.4426950408889634 * scale * query [i ]).to (query [i ].dtype )
117126
@@ -120,9 +129,12 @@ def application_without_kv_cache(query, key, value, attn_mask, scale, output):
120129 max = ntl .full ((query_i .shape [- 2 ],), float ("-inf" ), dtype = ntl .float32 )
121130
122131 for j in range (key .shape [0 ]):
123- qk = ntl .dot (query_i , ntl .trans (key [j ])) + attn_mask [ j ]
132+ qk = ntl .dot (query_i , ntl .trans (key [j ]))
124133 qk = ntl .where (key [j ].offsets (- 2 ) < key .source .shape [- 2 ], qk , float ("-inf" ))
125134
135+ if with_attn_mask :
136+ qk += attn_mask [j ]
137+
126138 next_max = ntl .maximum (max , ntl .max (qk , 1 ))
127139 stable_qk = ntl .exp2 (qk - next_max [:, None ])
128140
@@ -156,6 +168,7 @@ def make(with_kv_cache):
156168 for _ in range (4 )
157169 )
158170 scale = Tensor (0 )
171+ with_attn_mask = Tensor (0 , constexpr = True )
159172
160173 if with_kv_cache :
161174 application = application_with_kv_cache
@@ -173,6 +186,7 @@ def make(with_kv_cache):
173186 attn_mask ,
174187 scale ,
175188 output ,
189+ with_attn_mask ,
176190 )
177191
178192 return ninetoothed .make (
0 commit comments