Fix batched generation for prompts of different lengths #2216

Open · wants to merge 3 commits into main

Conversation

@RunFMe RunFMe commented Mar 27, 2025

Previous code

def _fast_prepare_inputs_for_generation(self, input_ids, **kwargs):
    past_key_values = kwargs.get("past_key_values", None)
    if past_key_values is not None:
        # Check for an uninitialized DynamicCache
        if len(past_key_values) == 0:
            past_key_values = None
            kwargs["past_key_values"] = None
        else:
            # Keep only the last token and the last column of the attention mask
            input_ids = input_ids[:, [-1]]
            kwargs["attention_mask"] = kwargs["attention_mask"][:, [-1]]

    if "cache_position" in kwargs:
        kwargs["position_ids"] = kwargs["cache_position"]
    return { "input_ids" : input_ids, **kwargs, }
pass

_fast_prepare_inputs_for_generation is called to build the forward-method arguments when generating a new token. Notice that it crops attention_mask and keeps only its last column. An attention mask of length 1 then trips a flag in the forward pass that replaces the attention mask with None, so padding in batches of different-length prompts is silently dropped.
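A minimal plain-Python illustration (no transformers dependency; the mask values are hypothetical) of why cropping the attention mask to its last column loses the padding information that batched generation needs:

```python
# Two left-padded prompts of different lengths:
# 0 marks padding, 1 marks real tokens.
attention_mask = [
    [0, 0, 1, 1, 1],  # short prompt, left-padded
    [1, 1, 1, 1, 1],  # full-length prompt
]

# The previous code kept only the last column: attention_mask[:, [-1]]
cropped = [[row[-1]] for row in attention_mask]

# Both rows are now identical, so the padding positions of the
# shorter prompt can no longer be distinguished.
print(cropped)  # -> [[1], [1]]
```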

I fixed it by copying a piece of code from the traditional prepare_inputs_for_generation, which calls base_model._prepare_4d_causal_attention_mask_with_cache_position when it is present. This lets models that implement this function (the most popular ones) build the attention mask as they see fit.
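A sketch of the delegation pattern described above (this mirrors the idea, not the exact PR diff; prepare_attention_mask and mask_kwargs are hypothetical names, while _prepare_4d_causal_attention_mask_with_cache_position follows Hugging Face transformers naming):

```python
def prepare_attention_mask(base_model, attention_mask, **mask_kwargs):
    # If the base model knows how to build a 4D causal mask for its
    # cache position, delegate to it instead of cropping the mask.
    build_mask = getattr(
        base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
    )
    if build_mask is not None:
        return build_mask(attention_mask, **mask_kwargs)
    # Fallback: keep the full 2D mask so padding is preserved.
    return attention_mask
```

With this shape, models lacking the helper still receive the uncropped mask, so shorter prompts in a batch keep their padding information.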


RunFMe commented Mar 29, 2025

@Datta0 What do you think? What else should we do before merging?


@Datta0 Datta0 left a comment

LGTM :) !


RunFMe commented Apr 3, 2025

@Datta0 do you know what the process is for getting approval of a maintainer?


Datta0 commented Apr 4, 2025

Hey @RunFMe, I will text Daniel about this.
Also, it'd be great if you could include samples from the previous and current code in the PR description.


RunFMe commented Apr 8, 2025

@Datta0 done :)
Let me know if there's anything else I can help with.
