Skip to content

Commit 6f6e869

Browse files
authored
[minor] phi4 train improvements (#1564)
1 parent e9c12e3 commit 6f6e869

File tree

4 files changed

+7
-8
lines changed

4 files changed

+7
-8
lines changed

configs/recipes/vision/phi4/sft/train.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ model:
3232

3333
data:
3434
train:
35-
collator_name: "vision_language_with_padding"
35+
collator_name: "vision_language_sft"
3636
use_torchdata: true
3737
datasets:
3838
- dataset_name: "merve/vqav2-small"
@@ -71,7 +71,7 @@ training:
7171
output_dir: "output/vlm_finetuned"
7272
trainer_type: "TRL_SFT"
7373
enable_gradient_checkpointing: True
74-
per_device_train_batch_size: 1 # Due to processor's handling of variable sized img-features.
74+
per_device_train_batch_size: 2
7575
gradient_accumulation_steps: 8
7676
max_steps: 20
7777

src/oumi/builders/collators.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,10 @@ def build_collator_from_config(
161161
collator_kwargs["allow_multi_image_inputs"] = (
162162
model_config.visual_config.supports_multiple_images
163163
)
164-
collator_kwargs["main_image_feature"] = (
165-
model_config.visual_config.main_image_feature
166-
)
164+
if collator_name == "vision_language_with_padding":
165+
collator_kwargs["main_image_feature"] = (
166+
model_config.visual_config.main_image_feature
167+
)
167168

168169
if collator_name == "vision_language_sft":
169170
processor_name = collator_kwargs.get(

src/oumi/core/collators/vision_language_sft_collator.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
from oumi.core.tokenizers.base_tokenizer import BaseTokenizer
2222
from oumi.core.types import Conversation
2323

24-
_PIXEL_VALUES_KEY = "pixel_values"
25-
2624

2725
class VisionLanguageSftCollator:
2826
def __init__(

tests/unit/builders/test_collators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_build_collator_from_config_no_collator(mock_tokenizer):
154154
assert collator is None
155155

156156

157-
def test_build_collator_from_config_no_collator_no_tokenzier():
157+
def test_build_collator_from_config_no_collator_no_tokenizer():
158158
training_config = TrainingConfig(
159159
data=DataParams(
160160
train=DatasetSplitParams(

0 commit comments

Comments
 (0)