add qwen2vl for sequence classification #34086

Open
wants to merge 11 commits into base: main

Conversation


@sorenmc sorenmc commented Oct 11, 2024

What does this PR do?

Adds sequence classification for Qwen2-VL. This work was done because there is currently no way to do text-image classification in transformers. This is useful for rerankers, reward models, etc. Mostly copied and stitched together from Qwen2VLForConditionalGeneration and LlamaForSequenceClassification.
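
A minimal usage sketch of what this would enable, for reviewers. The class name Qwen2VLForSequenceClassification, its export from transformers, and the num_labels argument are assumed here from the usual *ForSequenceClassification convention, so the final API may differ:

# Hypothetical usage sketch; Qwen2VLForSequenceClassification is assumed from this PR, not a released API.
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForSequenceClassification

model_id = "Qwen/Qwen2-VL-2B-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Qwen2VLForSequenceClassification.from_pretrained(model_id, num_labels=2)

messages = [{"role": "user", "content": [
    {"type": "image"},
    {"type": "text", "text": "Is this screenshot relevant to the query 'red sneakers'?"},
]}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image = Image.new("RGB", (224, 224), color="red")  # placeholder image for the sketch
inputs = processor(text=[text], images=[image], return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # (batch_size, num_labels), e.g. relevance scores for a reranker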

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker I saw you reviewed the original Qwen2-VL PR, so I would love for you to review this as well. Let me know if you think I need to do some refactoring or other changes.

@LysandreJik
Member

Maybe @zucchini-nlp can take a look as well!

@sorenmc
Author

sorenmc commented Oct 21, 2024

@ArthurZucker @zucchini-nlp Any of you who could take a look at this?

Member

@zucchini-nlp zucchini-nlp left a comment

Sorry I missed this one. LGTM in general, but I am not sure whether we add dedicated classes for specific tasks when there is no checkpoint for that model. I'll leave that question for @ArthurZucker

I left a couple comments, mostly nits. Thanks for working on this!

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py (review comment, outdated, resolved)
src/transformers/models/qwen2_vl/modeling_qwen2_vl.py (review comment, outdated, resolved)
Comment on lines +321 to +323
@unittest.skip("LM test not for VLM")
def test_attention_outputs(self):
    pass
Member

Did all the tests start failing after SequenceClassification was added? These tests should be okay with VLMs, so I don't think it is a good idea to skip them. We should rather try to fix them.

Author

Yes, these tests started failing after adding SequenceClassification! I can try to look into it.

Author

Looks like it is a problem with the pooling of logits that was copied from LlamaForSequenceClassification. Will investigate further.
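
For reference, a rough paraphrase of the last-token pooling in LlamaForSequenceClassification (not the exact transformers source; it assumes right padding): the logit at the last non-pad position is selected, and when pad_token_id is unset it silently falls back to the final position, which may be related to the failures above.

import torch

def pool_last_token_logits(logits, input_ids, pad_token_id):
    # logits: (batch, seq_len, num_labels); select the logit at the last real token
    batch_size = logits.shape[0]
    if pad_token_id is None:
        sequence_lengths = -1  # cannot detect padding, fall back to the final position
    else:
        # index of the last non-pad token, assuming right padding
        sequence_lengths = (input_ids != pad_token_id).int().sum(dim=-1) - 1
    return logits[torch.arange(batch_size, device=logits.device), sequence_lengths]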

@ArthurZucker
Collaborator

In general, we add these classes if:

  • the related feature request issue has user activity (10 reactions, for example)
  • we have strong demand from the community
  • a paper or a pretrained model exists
  • there is an official implementation of this task

I am not certain that we fulfill these criteria in this case, no? Or is there a source / reference implementation for this?

@sorenmc
Author

sorenmc commented Oct 31, 2024

In general, we add these classes if:

  • the related feature request issue has user activity (10 reactions, for example)
  • we have strong demand from the community
  • a paper or a pretrained model exists
  • there is an official implementation of this task

I am not certain that we fulfill these criteria in this case, no? Or is there a source / reference implementation for this?

Sorry for not being active here, I have had some busy weeks.

I have not seen any implementation or talk about this, but this is probably because multimodality is still not very mainstream given the increased compute requirements. We have seen big releases recently for retrieval (ColPali, ColQwen, DSE),
general multimodal LMs (Qwen2-VL, LLaVA-OneVision, Llama 3.2 Vision, e.g.), and now we are starting to see reward models (LLaVA-Critic). Therefore I believe it's just a matter of time before we see multimodal sequence classifiers used for reward modeling and retrieval rerankers. Currently none of these model backbones support these two tasks, so why not have at least one reference implementation?

I currently have this need, and would much prefer if this was a part of the standard library.

@sorenmc
Author

sorenmc commented Nov 21, 2024

For anyone else looking for a vision-language reranker/cross-encoder, check out this model released by lightonai, which takes Qwen2-VL 2B and fine-tunes it on ViDoRe.
Blog post
Model page

@Garibelhj

I am using the Qwen2-VL-2B model for a classification task, and I want to modify the Qwen2VLForConditionalGeneration by adding a linear classification head to adapt it to the task. I am unsure whether my modification is correct, as the test results show that the model has almost no classification ability after adding the classification head. Below is my training code.

class Qwen2VLWithClassificationHead(Qwen2VLForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        # Define the classification head
        self.mha_layer = torch.nn.MultiheadAttention(embed_dim=512, kdim=512, vdim=512, num_heads=1, batch_first=True)  # num_heads=1
        self.sigmoid = nn.Sigmoid()
        self.model = Qwen2VLModel(config)

        self.classification_head = nn.Linear(1536, 7, bias=False)
        torch.nn.init.xavier_uniform_(self.classification_head.weight)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value


    def forward(
        self,
        input_ids = None,
        attention_mask = None,
        position_ids = None,
        past_key_values = None,
        inputs_embeds = None,
        labels = None,
        labels_int = None,
        prob_id = None,
        use_cache = None,
        output_attentions = None,
        output_hidden_states = None,
        return_dict = None,
        pixel_values = None,
        pixel_values_videos = None,
        image_grid_thw= None,
        video_grid_thw = None,
        rope_deltas = None):

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.get_dtype())
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                image_mask = (
                    (input_ids == self.config.image_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                video_mask = (
                    (input_ids == self.config.video_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        transformer_outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Get the hidden states and pool the classification logits at the last position
        batch_size = input_ids.shape[0]
        classification_logits = self.classification_head(hidden_states)
        pool_classification_logits = classification_logits[torch.arange(batch_size, device=classification_logits.device), -1]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(pool_classification_logits.view(-1, 7), labels_int.view(-1))


        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pool_classification_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

def process_func(example):
    """
    Preprocess the dataset
    """
    number_map = {
        0: '(A) Pornographic Websites',
        1: '(B) Gambling Websites',
        2: '(C) Prize Scam Websites',
        3: '(D) Phishing Websites',
        4: '(E) Malicious Distribution Websites',
        5: '(F) Fraudulent E-Commerce Website',
        6: '(G) Fraudulent Financial Services Website',
    }
    MAX_LENGTH = 8192
    input_ids, attention_mask, labels = [], [], []
    conversation = example["conversations"]
    prompt = conversation['prompt']
    image = conversation['image']
    _id = example['id']
    solution = conversation['solution']
    prob_id = torch.tensor(int(_id))

    answer = conversation['answer']
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": f"{image}",
                    "resized_height": 280,
                    "resized_width": 280,
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]
    labels = number_map[answer]
    labels = f"The answer is {labels}." + "Because:" + solution
    response = tokenizer(f"The answer is {labels}", add_special_tokens=False)
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )  # get the prompt text
    image_inputs, video_inputs = process_vision_info(messages)  # get the (preprocessed) image/video data
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)

    inputs = {key: value.tolist() for key, value in inputs.items()}  # tensor -> list, to make concatenation easier
    instruction = inputs

    input_ids = (
            instruction["input_ids"][0] + response["input_ids"] + [tokenizer.pad_token_id]
    )

    attention_mask = instruction["attention_mask"][0] + response["attention_mask"] + [1]
    labels = (
            [-100] * len(instruction["input_ids"][0])
            + response["input_ids"]
            + [tokenizer.pad_token_id]
    )

    if len(input_ids) > MAX_LENGTH:  # truncate
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]

    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    labels_seq = torch.tensor(labels)
    labels_int = torch.tensor(int(answer))
    inputs['pixel_values'] = torch.tensor(inputs['pixel_values'])
    inputs['image_grid_thw'] = torch.tensor(inputs['image_grid_thw']).squeeze(0)  # from (1, h, w) to (h, w)

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels_seq,'labels_int':labels_int,'prob_id':prob_id,
            "pixel_values": inputs['pixel_values'], "image_grid_thw": inputs['image_grid_thw']}





# Download the Qwen2-VL model from ModelScope to a local directory
# model_dir = snapshot_download("Qwen/Qwen2-VL-7B-Instruct", cache_dir="./", revision="master")
from transformers import AutoModel

# Load the model weights with Transformers
tokenizer = AutoTokenizer.from_pretrained("/home/hongjiegu/projects/qwen2vl_cot/Qwen/Qwen2-VL-2B-Instruct", use_fast=False, trust_remote_code=True)
processor = AutoProcessor.from_pretrained("/home/hongjiegu/projects/qwen2vl_cot/Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
base_model = AutoModel.from_pretrained(
    "/home/hongjiegu/projects/qwen2vl_cot/Qwen/Qwen2-VL-2B-Instruct",
    device_map=device
)

# Initialize the custom model (key step)
model = Qwen2VLWithClassificationHead.from_pretrained(
    "/home/hongjiegu/projects/qwen2vl_cot/Qwen/Qwen2-VL-2B-Instruct",
    config=base_model.config,
    ignore_mismatched_sizes=True
)
model = model.to(device)

# Step 2: manually set custom parameters

model.enable_input_require_grads()  # required when gradient checkpointing is enabled

# Split into a training set and a test set, saved as data_vl_train.json and data_vl_test.json
train_dataset = train_ds.map(process_func)
eval_dataset = eval_ds.map(process_func)



# Configure the training arguments
args = TrainingArguments(
    output_dir="/home/hongjiegu/projects/GMMR/checkpoint/linear_classifier",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    logging_steps=10,
    logging_first_step=True,  # logging_first_step is a bool flag
    num_train_epochs=10,
    save_steps=1000,
    learning_rate=5e-5,
    save_on_each_node=True,
    gradient_checkpointing=True,
    report_to="none",
)
        
# Set up the SwanLab callback
swanlab_callback = SwanLabCallback(
    project="Qwen2-VL-finetune_2classes",
    experiment_name="GMMR_FWC_stage1",
)

# Configure the Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

# Start training
trainer.train()
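
One observation on the snippet above (a guess, not a confirmed diagnosis): DataCollatorForSeq2Seq right-pads the batch, so indexing classification_logits at position -1 reads a pad position for every sequence shorter than the longest in the batch, which by itself can make the head appear to learn nothing. Pooling at the last real token via the attention mask, similar in spirit to LlamaForSequenceClassification, would look roughly like this (using the names from the forward() above):

# Sketch: pool at the last non-pad token instead of the hard-coded -1 position.
# Assumes right padding and that attention_mask is 1 for real tokens, 0 for padding.
last_token_idx = attention_mask.sum(dim=-1) - 1  # (batch_size,)
pool_classification_logits = classification_logits[
    torch.arange(batch_size, device=classification_logits.device), last_token_idx
]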
