past_key_values not being set in model_inputs keys #36001

Open

AEJaspan opened this issue Feb 1, 2025 · 0 comments
System Info

transformers-cli env

Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.

- `transformers` version: 4.48.2
- Platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.28.1
- Safetensors version: 0.5.2
- Accelerate version: 1.3.0
- Accelerate config:    not found
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>
- Using GPU in script?: <fill in>
- GPU type: NVIDIA GeForce RTX 4070 Laptop GPU

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I was broadly following the steps outlined in this project's example: https://github.com/aehrc/cxrmate/blob/main/examples/cxrmate.ipynb

pip install --force-reinstall transformers==4.48.2

import torch, transformers
from PIL import Image
from torchvision import transforms
import pathlib

ckpt_name = 'aehrc/cxrmate'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

encoder_decoder = transformers.AutoModel.from_pretrained(ckpt_name, trust_remote_code=True).to(device)
encoder_decoder.eval()
tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(ckpt_name)
image_processor = transformers.AutoFeatureExtractor.from_pretrained(ckpt_name)

test_transforms = transforms.Compose(
    [
        transforms.Resize(size=image_processor.size['shortest_edge']),
        transforms.CenterCrop(size=[
            image_processor.size['shortest_edge'],
            image_processor.size['shortest_edge'],
        ]),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=image_processor.image_mean,
            std=image_processor.image_std,
        ),
    ]
)
# Load and preprocess a test image (duplicated to simulate two images per study):
image_1 = Image.open(pathlib.Path.cwd() / 'image1.jpg')
image_1 = image_1.convert('RGB')
image_1 = test_transforms(image_1)
stack = torch.stack([image_1, image_1], dim=0)

# Pad to a mini-batch of two studies, each with two images:
images = torch.nn.utils.rnn.pad_sequence([stack, stack], batch_first=True, padding_value=0.0)
images.shape

# No longitudinal history for either study:
previous_findings = [None, None]
previous_impression = [None, None]

# Tokenize prompt:
prompt = encoder_decoder.tokenize_prompt(
    previous_findings, 
    previous_impression, 
    tokenizer, 
    256, 
    add_bos_token_id=True,
)
outputs = encoder_decoder.generate(
    pixel_values=images.to(device),
    decoder_input_ids=prompt['input_ids'],
    special_token_ids=[
        tokenizer.additional_special_tokens_ids[
            tokenizer.additional_special_tokens.index('[PMT-SEP]')
        ],
        tokenizer.bos_token_id,
        tokenizer.sep_token_id,
    ],  
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    mask_token_id=tokenizer.pad_token_id,
    return_dict_in_generate=True,
    use_cache=True,
    max_length=256 + prompt['input_ids'].shape[1],
    num_beams=4,
)
# Drop the leading token (id 1) from every sequence if present:
if torch.all(outputs.sequences[:, 0] == 1):
    outputs.sequences = outputs.sequences[:, 1:]

outputs.sequences
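
The generate call fails before the lines above run, with the following traceback: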
File {...}/python3.10/site-packages/transformers/generation/utils.py:3467, in GenerationMixin._beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, generation_config, synced_gpus, **model_kwargs)
   3464 decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
   3466 while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
-> 3467     model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   3469     # prepare variable output controls (note: some models won't accept all output controls)
   3470     model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})

File ~/.cache/huggingface/modules/transformers_modules/aehrc/cxrmate/de123c14139435361ca20f3cf94abc9a8b06ca19/modelling_longitudinal.py:291, in LongitudinalPromptMultiCXREncoderDecoderModel.prepare_inputs_for_generation(self, input_ids, special_token_ids, mask_token_id, past_key_values, attention_mask, use_cache, encoder_outputs, **kwargs)
    281     token_type_ids = self.token_ids_to_token_type_ids_past(input_ids, special_token_ids, [0, 1, 0, 1])
    282     decoder_position_ids = decoder_position_ids[:, -1:]
    284 input_dict = {
    285     'attention_mask': attention_mask,
    286     'decoder_attention_mask': decoder_attention_mask,
    287     'decoder_input_ids': decoder_inputs['input_ids'],
    288     'decoder_token_type_ids': token_type_ids,
    289     'decoder_position_ids': decoder_position_ids,
    290     'encoder_outputs': encoder_outputs,
--> 291     'past_key_values': decoder_inputs['past_key_values'],
    292     'use_cache': use_cache,
    293 }
    294 return input_dict

KeyError: 'past_key_values'

The key is missing on the first decoding step because `prepare_inputs_for_generation` only adds it when `past_key_values` is not `None`. I found the fix in my case was to change:

if past_key_values is not None:
    model_inputs["past_key_values"] = past_key_values

to

model_inputs["past_key_values"] = past_key_values

at lines 384-386 of transformers/generation/utils.py, so that the key is always set (to None before the first decoding step).
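
For anyone who wants to avoid editing the installed package, the same effect can likely be had by wrapping the decoder's prepare_inputs_for_generation. This is an untested sketch: it assumes the remote cxrmate code builds decoder_inputs by calling encoder_decoder.decoder.prepare_inputs_for_generation and that this returns a plain dict, neither of which I have verified.

# Untested workaround sketch (assumes the remote model delegates to
# encoder_decoder.decoder.prepare_inputs_for_generation when building
# decoder_inputs): make the returned dict always carry 'past_key_values',
# defaulting to None, so the lookup at modelling_longitudinal.py:291 succeeds.
_orig_prepare = encoder_decoder.decoder.prepare_inputs_for_generation

def _prepare_with_pkv_default(*args, **kwargs):
    inputs = _orig_prepare(*args, **kwargs)
    inputs.setdefault('past_key_values', None)  # key present even on step one
    return inputs

encoder_decoder.decoder.prepare_inputs_for_generation = _prepare_with_pkv_default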

Expected behavior

I would expect a default value (e.g. None) to be applied to model_inputs["past_key_values"] so that the key is always present, as the cxrmate remote code assumes.
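
Failing that, the remote code could read the key defensively. A minimal, self-contained sketch of that pattern (the decoder_inputs dict below is a hypothetical stand-in for the real one built in modelling_longitudinal.py):

# Defensive alternative on the model side: dict.get turns a missing
# 'past_key_values' into None instead of a KeyError. The dict below is a
# stand-in for the first-step decoder_inputs in modelling_longitudinal.py.
decoder_inputs = {'input_ids': [[1]]}  # no 'past_key_values' on step one

# decoder_inputs['past_key_values']           # would raise KeyError, as above
past = decoder_inputs.get('past_key_values')  # -> None, no exception
assert past is None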

AEJaspan added the bug label on Feb 1, 2025