add VLM SFT engine pipeline #3590
base: main
Conversation
Code Review
This pull request introduces a supervised fine-tuning (SFT) pipeline for Vision Language Models (VLMs), including data processing, training scripts, and necessary modifications to the dataset and engine components. The changes are extensive and enable multi-turn and multi-image training for models like Qwen 2.5 VL.
My review identified a few critical and high-severity issues: a leftover breakpoint() call that would halt execution, in-place modification of data structures that can lead to subtle bugs and race conditions, and an incorrect docstring in a data preprocessing script. Addressing these points will improve the robustness and maintainability of the new VLM SFT pipeline.
```python
multi_modal_inputs = batch.pop("multi_modal_inputs", None)
if multi_modal_inputs is not None:
    assert len(multi_modal_inputs) == len(batch["input_ids"]), \
        "Length of 'multi_modal_inputs' must match the batch size."
```
The function rearrange_micro_batches modifies its input batch by calling pop("multi_modal_inputs", None). This is an undocumented side effect that can cause bugs if the caller reuses the batch object afterwards, since it will be missing the multi_modal_inputs key. Prefer reading the value without mutating the dictionary, or working on a copy; TensorDict.exclude() fits well here because it returns a new TensorDict without the specified key.
Suggested change:

```diff
-multi_modal_inputs = batch.pop("multi_modal_inputs", None)
-if multi_modal_inputs is not None:
-    assert len(multi_modal_inputs) == len(batch["input_ids"]), \
-        "Length of 'multi_modal_inputs' must match the batch size."
+multi_modal_inputs = batch.get("multi_modal_inputs", None)
+if multi_modal_inputs is not None:
+    batch = batch.exclude("multi_modal_inputs")
+    assert len(multi_modal_inputs) == len(batch["input_ids"]), \
+        "Length of 'multi_modal_inputs' must match the batch size."
```
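To illustrate the difference the suggestion is driving at, here is a minimal sketch using plain dicts as a stand-in for TensorDict (the function names are hypothetical, and the dict comprehension stands in for `TensorDict.exclude()`): `pop()` destructively removes the key from the caller's batch, while a `get()` plus copy-without-key approach leaves the caller's object intact.

```python
def read_with_pop(batch):
    # Mutates `batch`: the key is gone from the caller's dict after this call.
    return batch.pop("multi_modal_inputs", None)

def read_with_get(batch):
    # Reads without mutating; a new dict is built locally if the key exists.
    multi_modal_inputs = batch.get("multi_modal_inputs", None)
    if multi_modal_inputs is not None:
        # Stand-in for TensorDict.exclude(): a copy without the key.
        batch = {k: v for k, v in batch.items() if k != "multi_modal_inputs"}
    return multi_modal_inputs

batch_a = {"input_ids": [[1, 2]], "multi_modal_inputs": [{"pixel_values": 0}]}
batch_b = {"input_ids": [[1, 2]], "multi_modal_inputs": [{"pixel_values": 0}]}

read_with_pop(batch_a)
read_with_get(batch_b)

print("multi_modal_inputs" in batch_a)  # False: side effect on the caller
print("multi_modal_inputs" in batch_b)  # True: caller's batch is untouched
```

This is why a caller that builds micro-batches in a loop and reuses the original batch would silently lose its multimodal inputs after the first call with the `pop()` version.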
What does this PR do?
This PR adds a VLM SFT engine pipeline, which complements volcengine/verl#3589.
Currently implemented features: