
Commit 73f0d7f

TroyGarden authored and facebook-github-bot committed
add variable batch_size for training (#3388)
Summary:
Pull Request resolved: #3388
Pull Request resolved: #3387

# context

* APS uses a "variable batch size" during training, e.g., a smaller `batch_size` (like 32) to warm up, then a larger `batch_size` (like 64) for the rest of training:

```
batch_size_schedule:
  - batch_size: 32
    max_iters: 5
  - batch_size: 64
    max_iters: 999999999
```

* However, this becomes a problem for torch.export (PT2 IR), because the exported program assumes `batch_size` is constant.

NOTE: this "variable batch" concept is fundamentally different from "variable length" (VLE/VBE).

* In the variable batch scenario, each feature in the KJT shares the same `batch_size` within one batch/training iteration (it can only vary in a later iteration), so it follows the correlation `batch_size = len(kjt._lengths) // len(kjt._keys)`, and `kjt.stride()` returns the `batch_size` computed from `_lengths` and `_keys`.
* In the variable length scenario, each feature in the KJT can have a different `batch_size` within the same batch/training iteration, so there is no such correlation between `_lengths`, `_keys`, and `batch_size`.
* So this "variable batch size" **CANNOT** simply be handled by marking all input KJTs as variable length; instead, `batch_size` has to be marked as a dynamic shape via the `mark_dynamic_kjt` util function.

WARNING: it is the user's responsibility to make sure `variable_batch` is only used when `variable_length` is set to `False`; otherwise it will cause unexpected behavior with the dynamic shapes in torch.export.

Reviewed By: spmex, malaybag

Differential Revision: D82792378
1 parent c9ec08f commit 73f0d7f
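For orientation, here is a minimal end-to-end sketch of the flow this commit enables, condensed from the new test added in test_serializer.py below. The toy `Model` wrapper, table configs, and import paths are assumptions standing in for the test harness's `generate_model()`; treat this as an illustration, not the canonical API usage.

```
import torch
from torch import nn

from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import (
    decapsulate_ir_modules,
    encapsulate_ir_modules,
    mark_dynamic_kjt,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class Model(nn.Module):
    """Toy EBC wrapper; an assumed stand-in for generate_model() in the test."""

    def __init__(self) -> None:
        super().__init__()
        self.ebc = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name=f"t_{f}",
                    embedding_dim=4,
                    num_embeddings=10,
                    feature_names=[f],
                )
                for f in ["f1", "f2", "f3"]
            ]
        )

    def forward(self, kjt: KeyedJaggedTensor) -> torch.Tensor:
        return self.ebc(kjt).values()


model = Model()

# Every key shares one batch size (2 here), so len(lengths) == len(keys) * batch_size.
feature = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([0, 1, 2, 3, 2, 3]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),  # batch size = 2
)

# Mark only the batch size as dynamic; variable_length stays False.
collection = mark_dynamic_kjt(feature, variable_batch=True)

# Swap serializable sparse modules for their IR stand-ins, then export.
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
ep = torch.export.export(
    model,
    (feature,),
    {},
    dynamic_shapes=collection.dynamic_shapes(model, (feature,)),
    strict=False,
    preserve_module_call_signature=tuple(sparse_fqns),
)

# The exported program accepts a KJT with a different batch size (3 here).
feature_b3 = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7, 8, 8, 9]),  # batch size = 3
)
out = ep.module()(feature_b3)

# Round-trip back to eager modules, as the test does.
deserialized_model = decapsulate_ir_modules(torch.export.unflatten(ep), JsonSerializer)
deserialized_model.load_state_dict(model.state_dict())
```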

2 files changed (+84, -6 lines)

torchrec/ir/tests/test_serializer.py

Lines changed: 59 additions & 5 deletions
@@ -297,7 +297,7 @@ def test_serialize_deserialize_ebc(self) -> None:
         self.assertEqual(len(deserialized_out), len(eager_out))
         for deserialized, orginal in zip(deserialized_out, eager_out):
             self.assertEqual(deserialized.shape, orginal.shape)
-            self.assertTrue(torch.allclose(deserialized, orginal))
+            torch.testing.assert_close(deserialized, orginal)

     def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         model = self.generate_model_for_vbe_kjt()
@@ -374,14 +374,14 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         self.assertEqual(len(deserialized_out), len(eager_out))
         for deserialized, orginal in zip(deserialized_out, eager_out):
             self.assertEqual(deserialized.shape, orginal.shape)
-            self.assertTrue(torch.allclose(deserialized, orginal))
+            torch.testing.assert_close(deserialized, orginal)

         deserialized_out_2 = deserialized_model(kjt_2)

         self.assertEqual(len(deserialized_out_2), len(eager_out_2))
         for deserialized, orginal in zip(deserialized_out_2, eager_out_2):
             self.assertEqual(deserialized.shape, orginal.shape)
-            self.assertTrue(torch.allclose(deserialized, orginal))
+            torch.testing.assert_close(deserialized, orginal)

     def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
         model = self.generate_model()
@@ -428,7 +428,61 @@ def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:

         for i, tensor in enumerate(deserialized_out):
             self.assertEqual(eager_out[i].shape, tensor.shape)
-            assert torch.allclose(eager_out[i], tensor)
+            torch.testing.assert_close(eager_out[i], tensor)
+
+    def test_variable_batch_size_ebc_disabled_in_oss_compatibility(self) -> None:
+        model = self.generate_model()
+        feature1 = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),  # batch size = 2
+        )
+
+        feature2 = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7, 8, 8, 9]),  # batch size = 3
+        )
+        eager_out1 = model(feature1)
+        eager_out2 = model(feature2)
+        # feature1.lengths()
+        # feature2.lengths()
+
+        # Serialize EBC with sample input (feature1, batch size = 2)
+        collection = mark_dynamic_kjt(feature1, variable_batch=True)
+        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (feature1,),
+            {},
+            dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
+            strict=False,
+            # Allows KJT to not be unflattened and run a forward on unflattened EP
+            preserve_module_call_signature=tuple(sparse_fqns),
+        )
+
+        # Run forward on ExportedProgram
+        ep_output1 = ep.module()(feature1)
+        ep_output2 = ep.module()(feature2)
+
+        # other asserts
+        for eager_out, ep_out in [(eager_out1, ep_output1), (eager_out2, ep_output2)]:
+            for a, b in zip(eager_out, ep_out):
+                self.assertEqual(a.shape, b.shape)
+
+        # Deserialize EBC
+        unflatten_ep = torch.export.unflatten(ep)
+        deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
+        deserialized_model.load_state_dict(model.state_dict())
+
+        # Run forward on deserialized model
+        deserialized_out1 = deserialized_model(feature1)
+        deserialized_out2 = deserialized_model(feature2)
+
+        for e, d in ([eager_out1, deserialized_out1], [eager_out2, deserialized_out2]):
+            for a, b in zip(e, d):
+                self.assertEqual(a.shape, b.shape)
+                torch.testing.assert_close(a, b)

     def test_ir_emb_lookup_device(self) -> None:
         model = self.generate_model()
@@ -573,7 +627,7 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
         deserialized_out = deserialized_model(id_list_features)
         self.assertEqual(len(deserialized_out), len(eager_out))
         for x, y in zip(deserialized_out, eager_out):
-            self.assertTrue(torch.allclose(x, y))
+            torch.testing.assert_close(x, y)

     def test_regroup_as_dict_module(self) -> None:
         class Model(nn.Module):
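Aside from the new test, the recurring change in this file replaces `self.assertTrue(torch.allclose(...))` with `torch.testing.assert_close(...)`. A small stand-alone illustration of the difference (not part of the commit): `assert_close` raises with a mismatch report instead of collapsing the comparison to a bare boolean, which makes test failures easier to read.

```
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.5])

# allclose only returns True/False, so a failing assertTrue(...) reports nothing
# about which elements differ or by how much.
assert not torch.allclose(a, b)

# assert_close raises AssertionError describing the number of mismatched
# elements and the greatest absolute/relative difference.
try:
    torch.testing.assert_close(a, b)
except AssertionError as err:
    print(err)
```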

torchrec/ir/utils.py

Lines changed: 25 additions & 1 deletion
@@ -195,6 +195,7 @@ def mark_dynamic_kjt(
     kjt: KeyedJaggedTensor,
     shapes_collection: Optional[ShapesCollection] = None,
     variable_length: bool = False,
+    variable_batch: bool = False,
     vlen: Optional[DIM] = None,
     llen: Optional[DIM] = None,
 ) -> ShapesCollection:
@@ -211,10 +212,18 @@ def mark_dynamic_kjt(
     it will use the default name "vlen" for values, and "llen", "lofs" if variable length.
     A passed-in dynamic dim is useful if the dynamic dim is already used in other places.

+    Variable batch size means the batch size is dynamic across training iterations;
+    within one iteration/batch, all features share the same batch size, so the KJT still
+    follows the correlation: len(lengths) == len(keys) * batch_size.
+
+    In the variable length scenario, the batch size can differ per feature within one
+    iteration/batch, so the correlation len(lengths) == len(keys) * batch_size does not hold.
+
     Args:
         kjt (KeyedJaggedTensor): The KJT to make dynamic.
         shapes_collection (Optional[ShapesCollection]): The collection to update.
-        variable_length (bool): Whether the KJT is variable length.
+        variable_length (bool): Whether the KJT is variable length, i.e. len(lengths) != len(keys) * batch_size.
+        variable_batch (bool): Whether the KJT has a variable batch size, i.e. len(lengths) == len(keys) * batch_size. Only effective when variable_length is False.
         vlen (Optional[DIM]): The dynamic length for the values. If it's None, it will use the default name "vlen".
         llen (Optional[DIM]): The dynamic length for the lengths, it's only used when variable_length is true. If it's None, it will use the default name "llen".
         batch_size (Optional[DIM]): The dynamic length for the batch_size, it's only used when variable_length and mark_batch_size are both true.
@@ -245,6 +254,21 @@ def _has_dim(t: Optional[torch.Tensor]) -> bool:
             shapes_collection[kjt._lengths] = (llen,)
         if _has_dim(kjt._offsets):
             shapes_collection[kjt._offsets] = (llen + 1,)
+    elif variable_batch:
+        # Variable batch size means the batch size is dynamic across training iterations;
+        # within one iteration/batch, all features share the same batch size.
+        #
+        # This is fundamentally different from variable length, where the batch size can
+        # differ per feature within one iteration/batch.
+        #
+        # It's the user's responsibility to make sure that in a variable batch scenario,
+        # variable_batch is only used with variable_length set to False; otherwise it
+        # will lead to unexpected behavior with the dynamic shapes in torch.export.
+        batch_size = _get_dim("batch_size")
+        if _has_dim(kjt._lengths):
+            shapes_collection[kjt._lengths] = (batch_size * len(kjt.keys()),)
+        if _has_dim(kjt._offsets):
+            shapes_collection[kjt._offsets] = (batch_size * len(kjt.keys()) + 1,)
     return shapes_collection
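To make the docstring's correlation concrete, here is a small stand-alone check (a sketch, not part of the commit) built from the same fixture used in the new test: with three keys and six length entries, the shared batch size is 2 and `kjt.stride()` recovers it.

```
import torch

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([0, 1, 2, 3, 2, 3]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
)

# Lengths are the consecutive offset differences.
assert kjt.lengths().tolist() == [2, 0, 1, 1, 1, 1]

# Variable batch: one shared batch size per iteration, so
# len(lengths) == len(keys) * batch_size and stride() == batch_size.
batch_size = len(kjt.lengths()) // len(kjt.keys())
assert batch_size == 2
assert kjt.stride() == batch_size

# This is the relationship the new `elif variable_batch` branch encodes for
# torch.export: lengths get the symbolic shape (batch_size * len(keys),) and
# offsets (batch_size * len(keys) + 1,), with a single dynamic "batch_size" dim.
```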
