
Commit 302aa2f

updates + adding epochs
1 parent 9593a16 commit 302aa2f


3 files changed: +64 −59 lines changed


mlx_vlm/lora.py

Lines changed: 15 additions & 27 deletions
@@ -5,7 +5,6 @@
 import mlx.optimizers as optim
 from datasets import load_dataset
 
-from .prompt_utils import apply_chat_template
 from .trainer import Dataset, TrainingArgs, Colors, train, print_trainable_parameters
 from .trainer.utils import (
     apply_lora_layers,
@@ -61,31 +60,6 @@ def transform_to_messages(examples):
     else:
         raise ValueError(f"{Colors.FAIL}Dataset must have a 'messages' column or both 'question' and 'answer' columns{Colors.ENDC}")
 
-    if args.apply_chat_template:
-        logger.info(f"{Colors.OKBLUE}Applying chat template to the dataset{Colors.ENDC}")
-
-        def process_data(examples):
-            if config["model_type"] == "pixtral":
-                conversations = apply_chat_template(
-                    config=config,
-                    processor=processor,
-                    prompt=examples["messages"],
-                    return_messages=True,
-                )
-                examples["messages"] = [
-                    json.dumps(item, ensure_ascii=False) for item in conversations
-                ]
-            else:
-                examples["messages"] = apply_chat_template(
-                    config=config,
-                    processor=processor,
-                    prompt=examples["messages"],
-                    return_messages=True,
-                )
-            return examples
-
-        dataset = dataset.map(process_data)
-
     # Create Dataset objects
     train_dataset = Dataset(
         dataset,
@@ -95,6 +69,14 @@ def process_data(examples):
         image_resize_shape=args.image_resize_shape,
     )
 
+    if args.epochs is not None:
+        dataset_size = len(train_dataset)
+        steps_per_epoch = dataset_size // args.batch_size
+        total_steps = steps_per_epoch * args.epochs
+        iters = total_steps
+    else:
+        iters = args.iters
+
     # Use train dataset for validation if no validation dataset is provided
     val_dataset = None
 
@@ -144,7 +126,7 @@ def process_data(examples):
     # Create TrainingArgs
     training_args = TrainingArgs(
         batch_size=args.batch_size,
-        iters=args.iters,
+        iters=iters,
         steps_per_report=args.steps_per_report,
         steps_per_eval=args.steps_per_eval,
         steps_per_save=args.steps_per_save,
@@ -232,6 +214,12 @@ def process_data(examples):
     parser.add_argument(
         "--iters", type=int, default=1000, help="Number of iterations to train for"
     )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=None,
+        help="Number of epochs to train for. If provided, overrides --iters and computes steps from dataset size and batch size.",
+    )
     parser.add_argument(
         "--steps-per-report", type=int, default=10, help="Number of training steps between loss reporting"
     )
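
The new --epochs flag converts epochs to optimizer steps with floor division, so any partial final batch is dropped. A minimal standalone sketch of that conversion (epochs_to_iters is a hypothetical helper name, not part of mlx_vlm):

    def epochs_to_iters(dataset_size: int, batch_size: int, epochs: int) -> int:
        # Floor division drops a partial final batch, e.g. 103 samples at
        # batch size 4 give 25 steps per epoch, not 26.
        steps_per_epoch = dataset_size // batch_size
        return steps_per_epoch * epochs

    # 1000 samples, batch size 8, 3 epochs -> 125 steps/epoch * 3 = 375 iterations
    assert epochs_to_iters(1000, 8, 3) == 375

When --epochs is set it takes precedence over --iters, so a run would look something like python -m mlx_vlm.lora --dataset <path> --epochs 2 (flags other than --epochs and --iters are assumptions, not quoted from this diff).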

mlx_vlm/trainer/datasets.py

Lines changed: 48 additions & 31 deletions
@@ -2,26 +2,6 @@
 import json
 
 
-def get_prompt(model_type, processor, conversation):
-    if model_type == "paligemma":
-        return conversation
-
-    if "chat_template" in processor.__dict__.keys():
-        prompt = processor.apply_chat_template(
-            conversation,
-            tokenize=False,
-            add_generation_prompt=False,
-        )
-    elif "tokenizer" in processor.__dict__.keys():
-        prompt = processor.tokenizer.apply_chat_template(
-            conversation,
-            tokenize=False,
-            add_generation_prompt=False,
-        )
-
-    return prompt
-
-
 class Dataset:
     def __init__(
         self,
@@ -53,9 +33,21 @@ def __getitem__(self, idx):
         item = self.dataset[idx]
 
         images = item.get("images", item.get("image", None))
-        conversations = item.get("messages", item.get("conversations"))
-        if images in (None, "", []):
+
+        if images is None or images == "" or images == []:
             images = []
+        elif not isinstance(images, list):
+            images = [images]
+
+        image_paths = []
+        image_data = []
+        for img in images:
+            if isinstance(img, str):
+                image_paths.append(img)
+            else:
+                image_data.append(img)
+
+        conversations = item.get("messages", item.get("conversations"))
         prompts = []
 
         if isinstance(conversations, list) and isinstance(conversations[0], list):
@@ -67,27 +59,52 @@ def __getitem__(self, idx):
                         "Pixtral batch processing is not supported yet. Set batch size to 1."
                     )
 
-                prompt = get_prompt(
-                    self.config["model_type"], self.processor, conversation
-                )
+                if "chat_template" in self.processor.__dict__:
+                    prompt = self.processor.apply_chat_template(
+                        conversation,
+                        tokenize=False,
+                        add_generation_prompt=False,
+                        num_images=len(images),
+                        num_audios=0,
+                    )
+                else:
+                    prompt = self.processor.tokenizer.apply_chat_template(
+                        conversation,
+                        tokenize=False,
+                        add_generation_prompt=False,
+                        num_images=len(images),
+                        num_audios=0,
+                    )
                 prompts.append(prompt)
 
         else:
             if self.config["model_type"] == "pixtral":
                 conversations = [json.loads(i) for i in conversations]
-            prompt = get_prompt(
-                self.config["model_type"], self.processor, conversations
-            )
+            if "chat_template" in self.processor.__dict__:
+                prompt = self.processor.apply_chat_template(
+                    conversations,
+                    tokenize=False,
+                    add_generation_prompt=False,
+                    num_images=len(images),
+                    num_audios=0,
+                )
+            else:
+                prompt = self.processor.tokenizer.apply_chat_template(
+                    conversations,
+                    tokenize=False,
+                    add_generation_prompt=False,
+                    num_images=len(images),
+                    num_audios=0,
+                )
             prompts.append(prompt)
 
-        image_token_index = getattr(self.config, "image_token_index", "image_token_id")
 
         inputs = prepare_inputs(
            processor=self.processor,
-            images=images,
+            images=image_data,
             audio=None,
             prompts=prompts,
-            image_token_index=image_token_index,
+            image_token_index=getattr(self.config, "image_token_index", "image_token_id"),
             resize_shape=self.image_resize_shape
         )
 
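
The reworked __getitem__ first normalizes the images field and then separates file-path strings from in-memory image objects. A hedged sketch of that handling, factored into a standalone function for illustration (split_images is a hypothetical name):

    def split_images(images):
        # Normalize to a list: None, "" and [] become empty; a bare value is wrapped.
        if images is None or images == "" or images == []:
            images = []
        elif not isinstance(images, list):
            images = [images]
        # Strings are treated as file paths; anything else (e.g. a decoded
        # PIL image) is kept as in-memory image data.
        image_paths = [img for img in images if isinstance(img, str)]
        image_data = [img for img in images if not isinstance(img, str)]
        return image_paths, image_data

    # A mixed list splits into one path and one in-memory entry.
    paths, data = split_images(["cat.jpg", {"bytes": b"\x89PNG"}])
    assert paths == ["cat.jpg"] and data == [{"bytes": b"\x89PNG"}]

Note that in the hunks above only image_data reaches prepare_inputs; image_paths is collected but not yet consumed.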

mlx_vlm/trainer/utils.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ class Colors:
     BOLD = '\033[1m'
     UNDERLINE = '\033[4m'
 
-supported_for_training = {"qwen2_vl", "qwen2_5_vl"}
+supported_for_training = {"qwen2_vl", "qwen2_5_vl", "gemma3"}
 
 def grad_checkpoint(layer):
     """
