Commit ba72d57

Author: pytorchbot
Committed: 2025-09-20 nightly release (2a68dde)
Parent: 930ac32

File tree: 4 files changed (+100, -7 lines)


torchrec/ir/tests/test_serializer.py

Lines changed: 59 additions & 5 deletions
@@ -297,7 +297,7 @@ def test_serialize_deserialize_ebc(self) -> None:
         self.assertEqual(len(deserialized_out), len(eager_out))
         for deserialized, original in zip(deserialized_out, eager_out):
             self.assertEqual(deserialized.shape, original.shape)
-            self.assertTrue(torch.allclose(deserialized, original))
+            torch.testing.assert_close(deserialized, original)

     def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         model = self.generate_model_for_vbe_kjt()
@@ -374,14 +374,14 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         self.assertEqual(len(deserialized_out), len(eager_out))
         for deserialized, original in zip(deserialized_out, eager_out):
             self.assertEqual(deserialized.shape, original.shape)
-            self.assertTrue(torch.allclose(deserialized, original))
+            torch.testing.assert_close(deserialized, original)

         deserialized_out_2 = deserialized_model(kjt_2)

         self.assertEqual(len(deserialized_out_2), len(eager_out_2))
         for deserialized, original in zip(deserialized_out_2, eager_out_2):
             self.assertEqual(deserialized.shape, original.shape)
-            self.assertTrue(torch.allclose(deserialized, original))
+            torch.testing.assert_close(deserialized, original)

     def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
         model = self.generate_model()
@@ -428,7 +428,61 @@ def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:

         for i, tensor in enumerate(deserialized_out):
             self.assertEqual(eager_out[i].shape, tensor.shape)
-            assert torch.allclose(eager_out[i], tensor)
+            torch.testing.assert_close(eager_out[i], tensor)
+
+    def test_variable_batch_size_ebc_disabled_in_oss_compatibility(self) -> None:
+        model = self.generate_model()
+        feature1 = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),  # batch size = 2
+        )
+
+        feature2 = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7, 8, 8, 9]),  # batch size = 3
+        )
+        eager_out1 = model(feature1)
+        eager_out2 = model(feature2)
+
+        # Serialize EBC with sample input (feature1, batch size = 2)
+        collection = mark_dynamic_kjt(feature1, variable_batch=True)
+        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (feature1,),
+            {},
+            dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
+            strict=False,
+            # Allows KJT to not be unflattened and run a forward on unflattened EP
+            preserve_module_call_signature=tuple(sparse_fqns),
+        )
+
+        # Run forward on ExportedProgram
+        ep_output1 = ep.module()(feature1)
+        ep_output2 = ep.module()(feature2)
+
+        # Check that the ExportedProgram outputs match the eager outputs in shape
+        for eager_out, ep_out in [(eager_out1, ep_output1), (eager_out2, ep_output2)]:
+            for a, b in zip(eager_out, ep_out):
+                self.assertEqual(a.shape, b.shape)
+
+        # Deserialize EBC
+        unflatten_ep = torch.export.unflatten(ep)
+        deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
+        deserialized_model.load_state_dict(model.state_dict())
+
+        # Run forward on deserialized model
+        deserialized_out1 = deserialized_model(feature1)
+        deserialized_out2 = deserialized_model(feature2)
+
+        for e, d in [(eager_out1, deserialized_out1), (eager_out2, deserialized_out2)]:
+            for a, b in zip(e, d):
+                self.assertEqual(a.shape, b.shape)
+                torch.testing.assert_close(a, b)

     def test_ir_emb_lookup_device(self) -> None:
         model = self.generate_model()
@@ -573,7 +627,7 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
         deserialized_out = deserialized_model(id_list_features)
         self.assertEqual(len(deserialized_out), len(eager_out))
         for x, y in zip(deserialized_out, eager_out):
-            self.assertTrue(torch.allclose(x, y))
+            torch.testing.assert_close(x, y)

     def test_regroup_as_dict_module(self) -> None:
         class Model(nn.Module):
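
A note on the recurring swap in this file: torch.testing.assert_close raises an AssertionError that reports the number of mismatched elements and the largest absolute/relative difference, and it also checks dtype and device by default, whereas self.assertTrue(torch.allclose(...)) fails with only "False is not true". A minimal standalone sketch of the difference (not part of this commit):

    import torch

    a = torch.tensor([1.0, 2.0, 3.0])
    b = a + 1e-7  # within default float32 tolerances

    # Old style: on failure, the test log only says "False is not true".
    assert torch.allclose(a, b)

    # New style: on failure, assert_close reports which elements mismatch
    # and by how much; it also checks dtype and device by default.
    torch.testing.assert_close(a, b)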

torchrec/ir/utils.py

Lines changed: 25 additions & 1 deletion
@@ -195,6 +195,7 @@ def mark_dynamic_kjt(
     kjt: KeyedJaggedTensor,
     shapes_collection: Optional[ShapesCollection] = None,
     variable_length: bool = False,
+    variable_batch: bool = False,
     vlen: Optional[DIM] = None,
     llen: Optional[DIM] = None,
 ) -> ShapesCollection:
@@ -211,10 +212,18 @@ def mark_dynamic_kjt(
         it will use the default name "vlen" for values, and "llen", "lofs" if variable length.
         A passed-in dynamic dim is useful if the dynamic dim is already used in other places.

+    Variable batch size means the batch size is dynamic across training iterations,
+    while within one iteration/batch the batch size is the same for all features, so
+    the invariant len(lengths) == len(keys) * batch_size still holds.
+
+    In the variable length scenario, the batch size can differ per feature within one
+    iteration/batch, so the invariant len(lengths) == len(keys) * batch_size does not hold.
+
     Args:
         kjt (KeyedJaggedTensor): The KJT to make dynamic.
         shapes_collection (Optional[ShapesCollection]): The collection to update.
-        variable_length (bool): Whether the KJT is variable length.
+        variable_length (bool): Whether the KJT is variable length, i.e.,
+            len(lengths) != len(keys) * batch_size.
+        variable_batch (bool): Whether the KJT has a variable batch size, with
+            len(lengths) == len(keys) * batch_size. Only takes effect when
+            variable_length is False.
         vlen (Optional[DIM]): The dynamic length for the values. If it's None, it will use the default name "vlen".
         llen (Optional[DIM]): The dynamic length for the lengths, it's only used when variable_length is true. If it's None, it will use the default name "llen".
         batch_size (Optional[DIM]): The dynamic length for the batch_size, it's only used when variable_length and mark_batch_size are both true.
@@ -245,6 +254,21 @@ def _has_dim(t: Optional[torch.Tensor]) -> bool:
             shapes_collection[kjt._lengths] = (llen,)
         if _has_dim(kjt._offsets):
             shapes_collection[kjt._offsets] = (llen + 1,)
+    elif variable_batch:
+        # Variable batch size means the batch size is dynamic across training
+        # iterations; within one iteration/batch the batch size is the same
+        # for all features.
+        #
+        # This is fundamentally different from variable length, where the
+        # batch size differs per feature within one iteration/batch.
+        #
+        # It is the user's responsibility to pass variable_batch=True only
+        # when variable_length is False; otherwise the dynamic shapes passed
+        # to torch.export will behave unexpectedly.
+        batch_size = _get_dim("batch_size")
+        if _has_dim(kjt._lengths):
+            shapes_collection[kjt._lengths] = (batch_size * len(kjt.keys()),)
+        if _has_dim(kjt._offsets):
+            shapes_collection[kjt._offsets] = (batch_size * len(kjt.keys()) + 1,)
     return shapes_collection
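To make the invariant concrete: with a fixed set of keys, lengths holds one entry per (key, sample) pair, which is exactly the shape constraint registered above. A quick standalone sketch (not part of this commit), using the same style of KJT as the new test:

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # 3 keys with batch size 2 -> lengths has 3 * 2 = 6 entries, and
    # offsets has one extra leading zero, hence 6 + 1 = 7 entries.
    kjt = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f2", "f3"],
        values=torch.tensor([0, 1, 2, 3, 2, 3]),
        offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
    )
    assert len(kjt.lengths()) == len(kjt.keys()) * 2  # 6 == 3 * 2
    assert len(kjt.offsets()) == len(kjt.keys()) * 2 + 1  # 7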

torchrec/modules/tests/test_itep_embedding_modules.py

Lines changed: 15 additions & 0 deletions
@@ -190,6 +190,11 @@ def generate_expected_address_lookup_buffer(

         return torch.tensor(address_lookup, dtype=torch.int64)

+    # pyre-ignore[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
     def test_init_itep_module(self) -> None:
         itep_module = GenericITEPModule(
             table_name_to_unpruned_hash_sizes=self._table_name_to_unpruned_hash_sizes,
@@ -222,6 +227,11 @@ def test_init_itep_module(self) -> None:
             equal_nan=True,
         )

+    # pyre-ignore[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
     def test_init_itep_module_without_pruned_table(self) -> None:
         itep_module = GenericITEPModule(
             table_name_to_unpruned_hash_sizes={},
@@ -353,6 +363,11 @@ def test_eval_forward(
         # Check that reset_weight_momentum is not called
         self.assertEqual(mock_reset_weight_momentum.call_count, 0)

+    # pyre-ignore[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
     def test_iter_increment_per_forward(self) -> None:
         """Test that the iteration counter increments correctly with each forward pass."""
         itep_module = GenericITEPModule(
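
Since the same guard is repeated on three tests, one possible refactor (a sketch, not part of this commit) is to bind the skip decorator once at module level; unittest.skipIf returns a reusable decorator:

    import unittest

    import torch

    # Evaluated once at import time; reusable on any multi-GPU test.
    # The name _requires_multi_gpu is hypothetical.
    _requires_multi_gpu = unittest.skipIf(
        torch.cuda.device_count() <= 1,
        "Not enough GPUs, this test requires at least two GPUs",
    )

    class ExampleTest(unittest.TestCase):
        @_requires_multi_gpu
        def test_needs_two_gpus(self) -> None:
            self.assertGreater(torch.cuda.device_count(), 1)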

torchrec/sparse/jagged_tensor.py

Lines changed: 1 addition & 1 deletion
@@ -1066,7 +1066,7 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Tensor]]:
 def _assert_tensor_has_no_elements_or_has_integers(
     tensor: Optional[torch.Tensor], tensor_name: str
 ) -> None:
-    if is_torchdynamo_compiling() or tensor is None:
+    if torch.compiler.is_compiling() or tensor is None:
         # Skipping the check tensor.numel() == 0 to not guard on pt2 symbolic shapes.
         # TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable.
         return
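
Context for the one-line swap above: torch.compiler.is_compiling() is the public replacement for the internal is_torchdynamo_compiling() helper; it returns True while torch.compile (Dynamo) or torch.export is tracing, so eager-only validation can be skipped without guarding on symbolic shapes. A minimal sketch of the pattern, with a hypothetical helper name:

    import torch

    def _check_nonempty(tensor: torch.Tensor, name: str) -> None:
        # Eager-only sanity check; skipped under torch.compile/torch.export
        # so the numel() call does not introduce guards on symbolic shapes.
        if torch.compiler.is_compiling():
            return
        if tensor.numel() == 0:
            raise ValueError(f"{name} must be non-empty")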
