Skip to content

Add LUKVPress 🤖🤖🤖#236

Merged
maxjeblick merged 4 commits into
NVIDIA:mainfrom
molanyu:add-lukv-press
Jul 2, 2026
Merged

Add LUKVPress 🤖🤖🤖#236
maxjeblick merged 4 commits into
NVIDIA:mainfrom
molanyu:add-lukv-press

Conversation

@molanyu

@molanyu molanyu commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

PR description

This PR adds a new KV cache compression method, LUKV:
arxiv: https://arxiv.org/abs/2602.08585
code: https://github.com/baidu-baige/LU-KV

Checklist

Before submitting a PR, please make sure:

  • Tests are working (make test)

  • Code is formatted correctly (make style, on errors try fix with make format)

  • Copyright header is included

  • All commits are signed-off using git commit -s

  • (new press) mypress_press.py is in the presses directory

  • (new press) MyPress is in __init__.py

  • (new press) README.md is updated with a 1 liner about the new press in the Available presses section

  • (new press) New press is in the default_presses list in tests/default_presses.py

  • (new press) A docstring is provided that follows the same structure as the existing ones

@copy-pr-bot

copy-pr-bot Bot commented Jun 15, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@maxjeblick maxjeblick self-requested a review June 16, 2026 10:06

@maxjeblick maxjeblick left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for submitting the PR of LUKVPress!
I've done an initial round of review on the code, please find attached a proposed refactoring of the code here.

  • Added eager-mode guard as masking is silently ignored under eager
  • Uniform fallback is removed in favor of explicit failure
  • BUDGET_CURVE_URLS is now a dict, similar to other compression methods -> it allows to add more methods which are then looked up
  • Breaking change: In the refactor I don't enforce strict equality of expected attention press params. I'm open to dicussion here. The reason here is to 1. remove code that may be fragile when supporting more methods. 2. allow for experimentation under slight disagrement of the parameters. Another option would be to add a required_params fieild to the budget curve dictionary and compare values.
  • As in DuoAttention, we don't save the npy files to disc. Adding a global cache dir mechanism could be useful in the future.
  • load_budget_curve has been collapsed to a single method.
  • Additional refactors using codex

Please review the attached code, happy to discuss breaking changes (e.g. removal of params checks).

As for merging this PR, please also include either code to create additional budget_curves, or add instructions for it into the press' docstring, so it is possible to etend to more methods/llms.

Comment thread kvpress/presses/lukv_press.py Outdated
Comment thread kvpress/presses/lukv_press.py Outdated
Comment thread kvpress/presses/lukv_press.py
Comment thread kvpress/presses/lukv_press.py Outdated
@maxjeblick

Copy link
Copy Markdown
Collaborator
# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
from io import BytesIO
from typing import Optional

import numpy as np
import requests
import torch
from cachetools import LRUCache, cached  # type: ignore[import-untyped]
from torch import nn
from transformers import PreTrainedModel

from kvpress.presses.base_press import BasePress
from kvpress.presses.expected_attention_press import ExpectedAttentionPress
from kvpress.presses.scorer_press import ScorerPress

# (model name_or_path, scorer class name) -> curve url. Inject custom curves here for other models / scorers.
# A curve only matches the config it was profiled for, noted next to each entry.
BUDGET_CURVE_URLS = {
    # ExpectedAttentionPress(epsilon=2e-2), sink=4, window=1
    ("meta-llama/Llama-3.1-8B", "ExpectedAttentionPress"): "https://raw.githubusercontent.com/baidu-baige/LU-KV/main/evaluation/curve_data/llama-3.1-8b/ea_0.02_sink4_win1_llama_avg_ratio.npy",  # noqa: E501 # fmt: skip
}

cache = LRUCache(maxsize=128)


@cached(cache, key=lambda url: url)
def load_budget_curve(url: str) -> np.ndarray:
    """Download and cache (in-memory) a LU-KV budget curve. np.load uses allow_pickle=False."""
    return np.load(BytesIO(requests.get(url).content))


@dataclass
class LUKVPress(BasePress):
    """
    LU-KV: layer- and head-wise budget allocation around a score-based press.

    LU-KV wraps a ``ScorerPress`` and uses a pre-computed budget curve to allocate a different
    token budget to each attention layer and KV head, then evicts the lowest-scoring tokens per
    head (head-wise compression, see ``AdaKVPress``). The default configuration is
    ``LUKVPress(ExpectedAttentionPress(epsilon=2e-2), sink=4, window=1)``.

    Budget curves are model-, scorer- and configuration-specific. Published curves live in the
    module-level ``BUDGET_CURVE_URLS`` registry, keyed by ``(model name_or_path, scorer class)``.
    To use LU-KV with another model or scorer, register a curve, e.g.::

        BUDGET_CURVE_URLS[("mistralai/Mistral-7B-v0.3", "KeyDiffPress")] = "https://.../curve.npy"

    Based on Predicting Future Utility: Global Combinatorial Optimization for Task-Agnostic
    KV Cache Eviction (https://arxiv.org/abs/2602.08585).

    Parameters
    ----------
    press : ScorerPress, default=ExpectedAttentionPress(epsilon=2e-2)
        The scoring method used to rank cached tokens within each KV head.
    compression_ratio : float, default=0.0
        Fraction of KV pairs to remove globally. Selects the column of the budget curve to use.
    sink : int, default=4
        Number of initial tokens to protect from eviction.
    window : int, default=1
        Number of most recent tokens to protect from eviction.
    """

    press: ScorerPress = field(default_factory=lambda: ExpectedAttentionPress(epsilon=2e-2))
    compression_ratio: float = 0.0
    sink: int = 4
    window: int = 1

    # Loaded in post_init_from_model, reset on every call so the press never reuses a stale curve.
    _budget_curves: Optional[np.ndarray] = field(init=False, repr=False, default=None)

    def __post_init__(self):
        assert isinstance(self.press, ScorerPress), "LUKVPress requires a ScorerPress as input"
        assert 0 <= self.compression_ratio < 1, "compression_ratio must be in [0, 1)"
        assert self.sink >= 0 and self.window >= 0, "sink and window must be non-negative"

    def post_init_from_model(self, model: PreTrainedModel):
        self.press.post_init_from_model(model)
        self._budget_curves = self._load_budget_curves(model)

    def _load_budget_curves(self, model: PreTrainedModel) -> np.ndarray:
        """Look up, download (cached) and shape-check the budget curve for ``model``."""
        key = (model.config.name_or_path, type(self.press).__name__)
        url = BUDGET_CURVE_URLS.get(key)
        if url is None:
            raise KeyError(f"No LU-KV budget curve registered for {key}. Register one in BUDGET_CURVE_URLS.")

        curves = load_budget_curve(url)
        n_heads = getattr(model.config, "num_key_value_heads", None) or model.config.num_attention_heads
        assert curves.shape == (99, model.config.num_hidden_layers, n_heads), f"unexpected curve shape {curves.shape}"
        return curves

    def compress(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if self.compression_ratio == 0:
            return keys, values
        assert module.config._attn_implementation != "eager", "eager mode not supported"
        assert (
            self._budget_curves is not None
        ), "Call post_init_from_model (or use the press as a context manager) first."

        _, num_kv_heads, seq_len, _ = keys.shape
        scores = self.press.score(module, hidden_states, keys, values, attentions, kwargs)

        # Protect sink and window tokens from eviction by giving them the maximal score.
        # LU-KV targets long-context prefill, so seq_len >> sink + window and plain slicing is safe.
        protected_score = scores.max().item() + 1
        scores[..., : self.sink] = protected_score
        scores[..., seq_len - self.window :] = protected_score

        # Per-head prune ratios from the curve. Column i (0..98) maps to compression ratio (i+1)%.
        target_idx = min(98, max(0, round(self.compression_ratio * 100) - 1))
        prune_ratios = torch.as_tensor(self._budget_curves[target_idx, module.layer_idx], device=keys.device)
        ideal_keep = (1 - prune_ratios).clamp(0, 1) * seq_len

        # Largest-remainder apportionment: integer per-head keep counts summing to the global target
        keep_per_head = torch.floor(ideal_keep).long()
        remainder = int(torch.round(ideal_keep.sum()).item()) - int(keep_per_head.sum().item())
        if remainder > 0:
            top_indices = torch.topk(ideal_keep - keep_per_head, k=min(remainder, num_kv_heads)).indices
            keep_per_head[top_indices] += 1
        keep_per_head = keep_per_head.clamp(1, seq_len)

        if torch.all(keep_per_head >= seq_len):
            module.masked_key_indices = None
            return keys, values

        # Keep the highest-scoring tokens per head and mask the rest (see attention_patch.py)
        sorted_indices = torch.argsort(scores, dim=-1, descending=True, stable=True)
        rank = torch.arange(seq_len, device=scores.device).view(1, 1, seq_len).expand_as(sorted_indices)
        prune_mask = rank >= keep_per_head.view(1, num_kv_heads, 1)
        batch_indices, head_indices, rank_indices = torch.where(prune_mask)
        pruned_seq_indices = sorted_indices[batch_indices, head_indices, rank_indices]
        module.masked_key_indices = (batch_indices, head_indices, pruned_seq_indices)  # type: ignore[assignment]

        return keys, values

This is the proposed refactoring of the code.

Signed-off-by: tangziyao <672208690@qq.com>
Comment thread kvpress/presses/lukv_press.py
@maxjeblick

Copy link
Copy Markdown
Collaborator

/ok to test molanyu@4517790

@maxjeblick

Copy link
Copy Markdown
Collaborator

Hi! Please also merge latest main into your branch, we fixed an error w.r.t. github runner, causing failing tests.

molanyu added 3 commits July 1, 2026 16:44
Signed-off-by: tangziyao <672208690@qq.com>
Signed-off-by: tangziyao <672208690@qq.com>
Signed-off-by: tangziyao <672208690@qq.com>
@SimJeg

SimJeg commented Jul 2, 2026

Copy link
Copy Markdown
Collaborator

/ok to test f457238

@molanyu molanyu requested a review from maxjeblick July 2, 2026 07:52

@maxjeblick maxjeblick left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the contribution, LGTM!

@maxjeblick maxjeblick merged commit 520d418 into NVIDIA:main Jul 2, 2026
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants