22from unsloth .utils import attention_dispatch as attention_dispatch_utils
33from unsloth .utils .packing import configure_sample_packing , enable_sample_packing
44
5+ from collections .abc import Iterable
56from contextlib import ExitStack
67from types import SimpleNamespace
78from unittest .mock import patch
1011import torch
1112from datasets import Dataset
1213from trl import SFTConfig , SFTTrainer
14+ from trl .trainer .sft_trainer import DataCollatorForLanguageModeling
1315
1416
1517def _build_packed_training_setup (tmp_path , device ):
@@ -120,25 +122,16 @@ def __init__(self):
120122 self .generation_config = SimpleNamespace (attn_implementation = "sdpa" )
121123
122124
123- class _DummyCollator :
124- def __init__ (self ):
125- self .padding_free = False
126- self .return_position_ids = False
127-
128- def torch_call (self , examples ):
129- batch_size = len (examples )
130- max_tokens = 4
131- return {
132- "input_ids" : torch .zeros (batch_size , max_tokens , dtype = torch .long ),
133- "attention_mask" : torch .ones (batch_size , max_tokens , dtype = torch .long ),
134- "batch" : examples ,
135- }
136-
137-
138125class _DummyTrainer :
139126 def __init__ (self ):
140127 self .args = SimpleNamespace (remove_unused_columns = True )
141- self .data_collator = _DummyCollator ()
128+ self .data_collator = DataCollatorForLanguageModeling (
129+ pad_token_id = 0 ,
130+ completion_only_loss = False ,
131+ padding_free = True ,
132+ return_position_ids = False ,
133+ return_tensors = "pt" ,
134+ )
142135
143136
144137def test_enable_sample_packing ():
@@ -151,17 +144,21 @@ def test_enable_sample_packing():
151144 assert getattr (model , "_unsloth_allow_packed_overlength" ) is True
152145 assert getattr (model .child , "_unsloth_allow_packed_overlength" ) is True
153146
154- # trainer args are updated to keep the packed metadata
155- assert trainer .args .remove_unused_columns is False
156-
157147 collator = trainer .data_collator
158- assert collator .padding_free is True
159148 assert collator .return_position_ids is True
160149 assert getattr (collator , "_unsloth_packing_wrapped" ) is True
161150
162151 examples = [
163- {"seq_lengths" : [2 , 1 ]},
164- {"seq_lengths" : [3 ]},
152+ {
153+ "input_ids" : [0 , 1 , 2 ],
154+ "labels" : [0 , 1 , 2 ],
155+ "seq_lengths" : [2 , 1 ],
156+ },
157+ {
158+ "input_ids" : [3 , 4 , 5 ],
159+ "labels" : [3 , 4 , 5 ],
160+ "seq_lengths" : [3 ],
161+ },
165162 ]
166163 batch = collator .torch_call (examples )
167164
@@ -172,13 +169,43 @@ def test_enable_sample_packing():
172169 torch .tensor ([2 , 1 , 3 ], dtype = torch .int32 ),
173170 )
174171
175- assert "position_ids" in batch
176- assert torch .equal ( batch [ "position_ids" ][ 0 , : 3 ], torch . tensor ([ 0 , 1 , 0 ], dtype = torch .long ) )
177- assert torch .equal (batch ["position_ids" ][ 1 , : 3 ], torch . tensor ([ 0 , 1 , 2 ], dtype = torch . long ) )
172+ assert batch [ "input_ids" ]. shape == ( 1 , 6 )
173+ expected_positions = torch .tensor ([ 0 , 1 , 0 , 0 , 1 , 2 ], dtype = torch .long )
174+ assert torch .equal (batch ["position_ids" ]. view ( - 1 )[: 6 ], expected_positions )
178175
179- # attention_mask is dropped when return_position_ids is set
180- assert "attention_mask" not in batch
181- assert batch ["batch" ] == examples
176+
177+ def test_enable_sample_packing_trl_collator (tmp_path ):
178+ device = torch .device ("cuda" ) if torch .cuda .is_available () else torch .device ("cpu" )
179+ model , _ , trainer , _ = _build_packed_training_setup (tmp_path , device )
180+
181+ enable_sample_packing (model , trainer )
182+
183+ examples = [
184+ {
185+ "input_ids" : [0 , 1 , 2 ],
186+ "labels" : [0 , 1 , 2 ],
187+ "seq_lengths" : [2 , 1 ],
188+ },
189+ {
190+ "input_ids" : [3 , 4 , 5 ],
191+ "labels" : [3 , 4 , 5 ],
192+ "seq_lengths" : [3 ],
193+ },
194+ ]
195+
196+ batch = trainer .data_collator .torch_call (examples )
197+
198+ assert batch ["input_ids" ].shape == (1 , 6 )
199+ assert torch .equal (
200+ batch ["packed_seq_lengths" ],
201+ torch .tensor ([2 , 1 , 3 ], dtype = torch .int32 ),
202+ )
203+
204+ expected_positions = torch .tensor ([0 , 1 , 0 , 0 , 1 , 2 ], dtype = torch .long )
205+ assert torch .equal (batch ["position_ids" ].view (- 1 )[:6 ], expected_positions )
206+
207+ if hasattr (trainer , "accelerator" ):
208+ trainer .accelerator .free_memory ()
182209
183210
184211def test_packing_sdpa (tmp_path ):
0 commit comments