
Commit 4e5f1ae

enable batch flattening
1 parent 88d45aa commit 4e5f1ae

File tree

2 files changed: +61 -75 lines changed


tests/utils/test_packing.py

Lines changed: 55 additions & 28 deletions
@@ -2,6 +2,7 @@
 from unsloth.utils import attention_dispatch as attention_dispatch_utils
 from unsloth.utils.packing import configure_sample_packing, enable_sample_packing

+from collections.abc import Iterable
 from contextlib import ExitStack
 from types import SimpleNamespace
 from unittest.mock import patch
@@ -10,6 +11,7 @@
 import torch
 from datasets import Dataset
 from trl import SFTConfig, SFTTrainer
+from trl.trainer.sft_trainer import DataCollatorForLanguageModeling


 def _build_packed_training_setup(tmp_path, device):
@@ -120,25 +122,16 @@ def __init__(self):
         self.generation_config = SimpleNamespace(attn_implementation="sdpa")


-class _DummyCollator:
-    def __init__(self):
-        self.padding_free = False
-        self.return_position_ids = False
-
-    def torch_call(self, examples):
-        batch_size = len(examples)
-        max_tokens = 4
-        return {
-            "input_ids": torch.zeros(batch_size, max_tokens, dtype=torch.long),
-            "attention_mask": torch.ones(batch_size, max_tokens, dtype=torch.long),
-            "batch": examples,
-        }
-
-
 class _DummyTrainer:
     def __init__(self):
         self.args = SimpleNamespace(remove_unused_columns=True)
-        self.data_collator = _DummyCollator()
+        self.data_collator = DataCollatorForLanguageModeling(
+            pad_token_id=0,
+            completion_only_loss=False,
+            padding_free=True,
+            return_position_ids=False,
+            return_tensors="pt",
+        )


 def test_enable_sample_packing():
@@ -151,17 +144,21 @@ def test_enable_sample_packing():
     assert getattr(model, "_unsloth_allow_packed_overlength") is True
     assert getattr(model.child, "_unsloth_allow_packed_overlength") is True

-    # trainer args are updated to keep the packed metadata
-    assert trainer.args.remove_unused_columns is False
-
     collator = trainer.data_collator
-    assert collator.padding_free is True
     assert collator.return_position_ids is True
     assert getattr(collator, "_unsloth_packing_wrapped") is True

     examples = [
-        {"seq_lengths": [2, 1]},
-        {"seq_lengths": [3]},
+        {
+            "input_ids": [0, 1, 2],
+            "labels": [0, 1, 2],
+            "seq_lengths": [2, 1],
+        },
+        {
+            "input_ids": [3, 4, 5],
+            "labels": [3, 4, 5],
+            "seq_lengths": [3],
+        },
     ]
     batch = collator.torch_call(examples)

@@ -172,13 +169,43 @@ def test_enable_sample_packing():
         torch.tensor([2, 1, 3], dtype=torch.int32),
     )

-    assert "position_ids" in batch
-    assert torch.equal(batch["position_ids"][0, :3], torch.tensor([0, 1, 0], dtype=torch.long))
-    assert torch.equal(batch["position_ids"][1, :3], torch.tensor([0, 1, 2], dtype=torch.long))
+    assert batch["input_ids"].shape == (1, 6)
+    expected_positions = torch.tensor([0, 1, 0, 0, 1, 2], dtype=torch.long)
+    assert torch.equal(batch["position_ids"].view(-1)[:6], expected_positions)

-    # attention_mask is dropped when return_position_ids is set
-    assert "attention_mask" not in batch
-    assert batch["batch"] == examples
+
+def test_enable_sample_packing_trl_collator(tmp_path):
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    model, _, trainer, _ = _build_packed_training_setup(tmp_path, device)
+
+    enable_sample_packing(model, trainer)
+
+    examples = [
+        {
+            "input_ids": [0, 1, 2],
+            "labels": [0, 1, 2],
+            "seq_lengths": [2, 1],
+        },
+        {
+            "input_ids": [3, 4, 5],
+            "labels": [3, 4, 5],
+            "seq_lengths": [3],
+        },
+    ]
+
+    batch = trainer.data_collator.torch_call(examples)
+
+    assert batch["input_ids"].shape == (1, 6)
+    assert torch.equal(
+        batch["packed_seq_lengths"],
+        torch.tensor([2, 1, 3], dtype=torch.int32),
+    )
+
+    expected_positions = torch.tensor([0, 1, 0, 0, 1, 2], dtype=torch.long)
+    assert torch.equal(batch["position_ids"].view(-1)[:6], expected_positions)
+
+    if hasattr(trainer, "accelerator"):
+        trainer.accelerator.free_memory()


 def test_packing_sdpa(tmp_path):
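For reference, a minimal hand-built sketch of the flattened ("batch flattening") layout these assertions check. The tensors below are constructed directly rather than by TRL's DataCollatorForLanguageModeling, so treat it purely as an illustration of the expected values: with padding-free collation the two examples collapse into a single row of six tokens, position ids restart for every packed sequence, and the wrapper contributes packed_seq_lengths.

import torch

# Two packed examples, mirroring the test above.
examples = [
    {"input_ids": [0, 1, 2], "seq_lengths": [2, 1]},
    {"input_ids": [3, 4, 5], "seq_lengths": [3]},
]

# Batch flattening: concatenate every token into one row of shape (1, 6).
flat_tokens = [tok for ex in examples for tok in ex["input_ids"]]
input_ids = torch.tensor([flat_tokens], dtype=torch.long)

# Per-sequence lengths across the whole flattened batch: [2, 1, 3].
lengths = [length for ex in examples for length in ex["seq_lengths"]]
packed_seq_lengths = torch.tensor(lengths, dtype=torch.int32)

# Position ids restart at 0 for each packed sequence: [0, 1, 0, 0, 1, 2].
position_ids = torch.cat([torch.arange(length) for length in lengths]).unsqueeze(0)

assert input_ids.shape == (1, 6)
assert packed_seq_lengths.tolist() == [2, 1, 3]
assert position_ids.tolist() == [[0, 1, 0, 0, 1, 2]]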

unsloth/utils/packing.py

Lines changed: 6 additions & 47 deletions
@@ -55,15 +55,6 @@ def enable_sample_packing(
     sequence_lengths_key: str = "seq_lengths",
 ) -> None:
     """Enable runtime support for packed batches on an existing trainer."""
-    train_bs = getattr(trainer.args, "per_device_train_batch_size", 1)
-    eval_bs = getattr(trainer.args, "per_device_eval_batch_size", 1)
-
-    if train_bs != 1 or eval_bs != 1:
-        raise ValueError(
-            "Sample packing requires per_device_train_batch_size=1 and "
-            f"per_device_eval_batch_size=1; received {train_bs}, {eval_bs}."
-        )
-
     def _mark_allow_overlength(module):
         if hasattr(module, "max_seq_length"):
             setattr(module, "_unsloth_allow_packed_overlength", True)
@@ -72,17 +63,14 @@ def _mark_allow_overlength(module):

     _mark_allow_overlength(model)

-    if hasattr(trainer.args, "remove_unused_columns"):
-        trainer.args.remove_unused_columns = False
-
     collator = getattr(trainer, "data_collator", None)
-    if collator is None or not hasattr(collator, "torch_call"):
-        return
-    if getattr(collator, "_unsloth_packing_wrapped", False):
+    if (
+        collator is None
+        or not hasattr(collator, "torch_call")
+        or getattr(collator, "_unsloth_packing_wrapped", False)
+    ):
         return

-    if hasattr(collator, "padding_free"):
-        collator.padding_free = True
     if hasattr(collator, "return_position_ids"):
         collator.return_position_ids = True

@@ -92,41 +80,12 @@ def torch_call_with_lengths(examples: Sequence[dict]):
         batch = original_torch_call(examples)
         if examples and isinstance(examples[0], dict):
             seq_lengths: list[int] = []
-            per_example_counts: list[int] = []
             for example in examples:
                 lengths = example.get(sequence_lengths_key)
                 if isinstance(lengths, Iterable):
-                    numeric_lengths = [int(length) for length in lengths]
-                    seq_lengths.extend(numeric_lengths)
-                    per_example_counts.append(len(numeric_lengths))
-                else:
-                    per_example_counts.append(0)
+                    seq_lengths.extend(int(length) for length in lengths)
             if seq_lengths:
                 batch["packed_seq_lengths"] = torch.tensor(seq_lengths, dtype=torch.int32)
-
-            position_ids = batch.get("position_ids")
-            input_ids = batch.get("input_ids")
-            if position_ids is None and input_ids is not None:
-                position_ids = torch.zeros_like(
-                    input_ids, dtype=torch.long, device=input_ids.device
-                )
-
-            if position_ids is not None and input_ids is not None:
-                seq_index = 0
-                for row_idx, count in enumerate(per_example_counts):
-                    cursor = 0
-                    for _ in range(count):
-                        length = seq_lengths[seq_index]
-                        if length > 0:
-                            position_ids[row_idx, cursor : cursor + length] = torch.arange(
-                                length, dtype=torch.long, device=position_ids.device
-                            )
-                        cursor += length
-                        seq_index += 1
-                batch["position_ids"] = position_ids
-
-        if "attention_mask" in batch and getattr(collator, "return_position_ids", False):
-            batch.pop("attention_mask")
         return batch

     collator.torch_call = torch_call_with_lengths
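After this change the deleted position-id bookkeeping is delegated to the padding-free collator; all the wrapper still does is collect seq_lengths and attach them as packed_seq_lengths. A rough standalone sketch of that wrapping pattern, using a hypothetical _FlatteningCollator as a stand-in for the real TRL collator:

from collections.abc import Iterable, Sequence

import torch


class _FlatteningCollator:
    """Hypothetical stand-in for a padding-free collator (illustration only)."""

    def torch_call(self, examples: Sequence[dict]) -> dict:
        # Flatten all examples into a single row, as a padding-free collator would.
        flat = [tok for ex in examples for tok in ex["input_ids"]]
        return {"input_ids": torch.tensor([flat], dtype=torch.long)}


def wrap_with_packed_lengths(collator, sequence_lengths_key: str = "seq_lengths"):
    """Attach packed_seq_lengths to each batch, mirroring the simplified wrapper above."""
    original_torch_call = collator.torch_call

    def torch_call_with_lengths(examples: Sequence[dict]) -> dict:
        batch = original_torch_call(examples)
        seq_lengths: list[int] = []
        for example in examples:
            lengths = example.get(sequence_lengths_key)
            if isinstance(lengths, Iterable):
                seq_lengths.extend(int(length) for length in lengths)
        if seq_lengths:
            batch["packed_seq_lengths"] = torch.tensor(seq_lengths, dtype=torch.int32)
        return batch

    collator.torch_call = torch_call_with_lengths
    collator._unsloth_packing_wrapped = True
    return collator


collator = wrap_with_packed_lengths(_FlatteningCollator())
batch = collator.torch_call(
    [
        {"input_ids": [0, 1, 2], "seq_lengths": [2, 1]},
        {"input_ids": [3, 4, 5], "seq_lengths": [3]},
    ]
)
assert batch["packed_seq_lengths"].tolist() == [2, 1, 3]

In the real code path the stand-in's role is played by DataCollatorForLanguageModeling(padding_free=True, ...), which also emits the restarted position ids asserted in the tests.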
