@@ -166,14 +166,28 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]:
166
166
return ret
167
167
168
168
169
- def generate_block_hash_extra_keys (
170
- request : Request , start_token_idx : int , end_token_idx : int ,
171
- start_mm_idx : int ) -> Tuple [Optional [Tuple [Any , ...]], int ]:
172
- """Generate extra keys for the block hash. The extra keys can come from
173
- the multi-modal inputs and request specific metadata (e.g., LoRA ID).
174
- For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
175
- indicate a mm input contained in the block and its starting offset in
176
- the block tokens.
169
+ def need_extra_keys (request : Request ) -> bool :
170
+ """Check whether the blocks allocated to this request need extra hash keys.
171
+
172
+ Args:
173
+ request (Request): The request.
174
+
175
+ Returns:
176
+ bool: Whether blocks allocated to this request need extra hash keys.
177
+ """
178
+
179
+ # Multimodal requests need to include the MM hash.
180
+ # LoRA requests need to include the LoRA ID.
181
+ return bool (request .mm_positions ) or (request .lora_request is not None )
182
+
183
+
184
+ def _gen_mm_extra_hash_keys (request : Request , start_token_idx : int ,
185
+ end_token_idx : int ,
186
+ start_mm_idx : int ) -> Tuple [List [Any ], int ]:
187
+ """Generate extra keys related to MultiModal request for block hash
188
+ computation. For multi-modal inputs, the extra keys are
189
+ (mm_hash, start_offset) that indicate a mm input contained in the
190
+ block and its starting offset in the block tokens.
177
191
178
192
Args:
179
193
request: The request object.
@@ -184,10 +198,11 @@ def generate_block_hash_extra_keys(
184
198
Returns:
185
199
A tuple of extra keys and the next multi-modal index.
186
200
"""
201
+ extra_keys : List [Any ] = []
187
202
188
203
mm_positions , mm_hashes = request .mm_positions , request .mm_hashes
189
204
if not mm_positions :
190
- return None , start_mm_idx
205
+ return extra_keys , start_mm_idx
191
206
192
207
if mm_positions and len (mm_positions ) != len (mm_hashes ):
193
208
raise ValueError (
@@ -200,14 +215,13 @@ def generate_block_hash_extra_keys(
200
215
# range. This usually happens in the late prefill phase and decoding phase.
201
216
if mm_positions [- 1 ]["offset" ] + mm_positions [- 1 ][
202
217
"length" ] < start_token_idx :
203
- return None , start_mm_idx
218
+ return extra_keys , start_mm_idx
204
219
205
220
# Support start_mm_idx == -1 to indicate the last mm input.
206
221
if start_mm_idx < 0 :
207
222
assert - start_mm_idx <= len (mm_positions )
208
223
start_mm_idx = len (mm_positions ) + start_mm_idx
209
224
210
- extra_keys = []
211
225
curr_mm_idx = start_mm_idx
212
226
while mm_positions and curr_mm_idx < len (mm_positions ):
213
227
assert mm_hashes [curr_mm_idx ] is not None
@@ -233,7 +247,50 @@ def generate_block_hash_extra_keys(
233
247
else :
234
248
# This block has not reached the current mm input.
235
249
break
236
- return tuple (extra_keys ), curr_mm_idx
250
+ return extra_keys , curr_mm_idx
251
+
252
+
253
+ def _gen_lora_extra_hash_keys (request : Request ) -> List [int ]:
254
+ """Generate extra keys related to LoRA for block hash computation.
255
+
256
+ Args:
257
+ request: The request object.
258
+
259
+ Returns:
260
+ Return LoRA id of the request if it is a LoRA request. Return empty
261
+ list otherwise.
262
+ """
263
+ if not request .lora_request :
264
+ return []
265
+ return [request .lora_request .lora_int_id ]
266
+
267
+
268
+ def generate_block_hash_extra_keys (
269
+ request : Request , start_token_idx : int , end_token_idx : int ,
270
+ start_mm_idx : int ) -> Tuple [Optional [Tuple [Any , ...]], int ]:
271
+ """Generate extra keys for the block hash. The extra keys can come from
272
+ the multi-modal inputs and request specific metadata (e.g., LoRA ID).
273
+
274
+ Args:
275
+ request: The request object.
276
+ start_token_idx: The start token index of the block.
277
+ end_token_idx: The end token index of the block.
278
+ start_mm_idx: The start multi-modal index of the block.
279
+
280
+ Returns:
281
+ A tuple of extra keys and the next multi-modal index.
282
+ """
283
+ mm_extra_keys : List [Any ]
284
+ mm_extra_keys , new_start_mm_idx = _gen_mm_extra_hash_keys (
285
+ request , start_token_idx , end_token_idx , start_mm_idx )
286
+ lora_extra_keys : List [int ] = _gen_lora_extra_hash_keys (request )
287
+
288
+ extra_keys : List [Any ] = lora_extra_keys + mm_extra_keys
289
+
290
+ if not extra_keys :
291
+ return None , new_start_mm_idx
292
+
293
+ return tuple (extra_keys ), new_start_mm_idx
237
294
238
295
239
296
def hash_block_tokens (
@@ -245,9 +302,6 @@ def hash_block_tokens(
245
302
prefix caching. We use LRU cache for this function to avoid recomputing
246
303
hash values for the same block contents.
247
304
248
- TODO: Support arbitrary metadata so that we could support more
249
- features such as LoRA adapter.
250
-
251
305
Args:
252
306
parent_block_hash: The hash of the parent block. None
253
307
if this is the first block.
@@ -276,14 +330,9 @@ def hash_request_tokens(block_size: int,
276
330
The list of computed hash values.
277
331
"""
278
332
token_ids = request .all_token_ids
279
- mm_positions , mm_hashes = request .mm_positions , request .mm_hashes
280
- if mm_positions and len (mm_positions ) != len (mm_hashes ):
281
- raise ValueError (
282
- "The number of multi-modal positions and hashes must match." )
283
333
284
- # TODO: Extend this to support other features such as LoRA.
285
- need_extra_keys = bool (mm_positions )
286
- extra_keys = None
334
+ req_need_extra_keys = need_extra_keys (request )
335
+ req_extra_keys = None
287
336
curr_mm_idx = 0
288
337
289
338
ret = []
@@ -295,13 +344,13 @@ def hash_request_tokens(block_size: int,
295
344
if len (block_token_ids ) < block_size :
296
345
break
297
346
298
- # Add extra keys if the block is a multi-modal block.
299
- if need_extra_keys :
300
- extra_keys , curr_mm_idx = generate_block_hash_extra_keys (
347
+ if req_need_extra_keys :
348
+ # MM and LoRA requests need extra keys for block-hash computation.
349
+ req_extra_keys , curr_mm_idx = generate_block_hash_extra_keys (
301
350
request , start , end , curr_mm_idx )
302
351
303
352
block_hash = hash_block_tokens (parent_block_hash_value ,
304
- block_token_ids , extra_keys )
353
+ block_token_ids , req_extra_keys )
305
354
ret .append (block_hash )
306
355
parent_block_hash_value = block_hash .hash_value
307
356
return ret
0 commit comments