Conversation

@The-Hierophant (Contributor)

What does this PR do?

This PR adds a VLM SFT engine pipeline, which complements volcengine/verl#3589.

Currently implemented features:

  • Qwen 2.5 VL with FSDP backend
  • Multi-turn & multi-image training
  • Sequence balancing


@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a supervised fine-tuning (SFT) pipeline for Vision Language Models (VLMs), including data processing, training scripts, and necessary modifications to the dataset and engine components. The changes are extensive and enable multi-turn and multi-image training for models like Qwen 2.5 VL.

My review identified a few critical and high-severity issues: a leftover breakpoint() call that would halt execution, in-place modification of data structures that can lead to subtle bugs and race conditions, and an incorrect docstring in a data preprocessing script. Addressing these points will improve the robustness and maintainability of the new VLM SFT pipeline.

Comment on lines +282 to +285
    multi_modal_inputs = batch.pop("multi_modal_inputs", None)
    if multi_modal_inputs is not None:
        assert len(multi_modal_inputs) == len(batch["input_ids"]), \
            "Length of 'multi_modal_inputs' must match the batch size."

Severity: high

The function rearrange_micro_batches modifies its input batch via pop("multi_modal_inputs", None). This undocumented side effect can cause bugs if the caller reuses the batch object, which will then be missing the multi_modal_inputs key. It is safer to read the value without mutating the dictionary: TensorDict.exclude() returns a new TensorDict without the specified key, leaving the caller's object intact.

Suggested change

    - multi_modal_inputs = batch.pop("multi_modal_inputs", None)
    - if multi_modal_inputs is not None:
    -     assert len(multi_modal_inputs) == len(batch["input_ids"]), \
    -         "Length of 'multi_modal_inputs' must match the batch size."
    + multi_modal_inputs = batch.get("multi_modal_inputs", None)
    + if multi_modal_inputs is not None:
    +     batch = batch.exclude("multi_modal_inputs")
    +     assert len(multi_modal_inputs) == len(batch["input_ids"]), \
    +         "Length of 'multi_modal_inputs' must match the batch size."
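To make the hazard concrete, here is a minimal sketch using plain Python dicts as stand-ins for the TensorDict (so it runs without the tensordict package); the function names are hypothetical and only illustrate the pattern. pop() mutates the caller's object, while the get-then-rebuild approach (analogous to TensorDict.exclude(), which returns a new TensorDict) leaves it untouched:

```python
def rearrange_with_pop(batch):
    # Hazardous: pop() removes the key from the caller's batch in place.
    return batch.pop("multi_modal_inputs", None)

def rearrange_with_exclude(batch):
    # Side-effect free: read the value, then work on a new mapping that
    # omits the key (what TensorDict.exclude() does for TensorDicts).
    mm = batch.get("multi_modal_inputs")
    if mm is not None:
        batch = {k: v for k, v in batch.items() if k != "multi_modal_inputs"}
    return mm

caller_batch = {"input_ids": [[1, 2]], "multi_modal_inputs": [{"pixel_values": None}]}
rearrange_with_pop(caller_batch)
print("multi_modal_inputs" in caller_batch)  # False: the caller lost the key

caller_batch = {"input_ids": [[1, 2]], "multi_modal_inputs": [{"pixel_values": None}]}
rearrange_with_exclude(caller_batch)
print("multi_modal_inputs" in caller_batch)  # True: the caller is unaffected
```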
