Commit 101e253

coufon authored and facebook-github-bot committed
Support skipping offsets computation in KJT to_dict (#3503)
Summary: This diff optimizes KeyedJaggedTensor.to_dict by optionally skipping several unnecessary offset computations. Skipping them can noticeably improve latency when offsets are not needed in the outputs. Differential Revision: D85995629
1 parent 75dfb7f commit 101e253
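
For orientation, here is a minimal usage sketch of the new flag (illustrative only; it mirrors the data used in the tests below and assumes a standard TorchRec install):

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two keys with stride 3: "f1" owns lengths [2, 0, 1], "f2" owns [1, 1, 3].
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
    lengths=torch.IntTensor([2, 0, 1, 1, 1, 3]),
)

# Skip the per-key offsets computation; values and lengths are unaffected.
jt_dict = kjt.to_dict(compute_offsets=False)
assert jt_dict["f1"].offsets_or_none() is None
assert torch.equal(jt_dict["f1"].lengths(), torch.IntTensor([2, 0, 1]))

# Caveat: to_dict caches its result in self._jt_dict, so the flag only
# takes effect on the first to_dict call for a given KJT instance.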

File tree

2 files changed, +76 -22 lines


torchrec/sparse/jagged_tensor.py

Lines changed: 28 additions & 22 deletions
@@ -1402,6 +1402,7 @@ def _maybe_compute_kjt_to_jt_dict(
     variable_stride_per_key: bool,
     weights: Optional[torch.Tensor],
     jt_dict: Optional[Dict[str, JaggedTensor]],
+    compute_offsets: bool = True,
 ) -> Dict[str, JaggedTensor]:
     if not length_per_key:
         return {}
@@ -1418,50 +1419,51 @@ def _maybe_compute_kjt_to_jt_dict(
             torch._check(cat_size <= total_size)
         torch._check(cat_size == total_size)
     torch._check_is_size(stride)
+
     values_list = torch.split(values, length_per_key)
+    split_offsets: tuple[torch.Tensor] | None = None
     if variable_stride_per_key:
         split_lengths = torch.split(lengths, stride_per_key)
-        split_offsets = [
-            torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-            for lengths in split_lengths
-        ]
+        if compute_offsets:
+            split_offsets = [
+                torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+                for lengths in split_lengths
+            ]
     elif pt2_guard_size_oblivious(lengths.numel() > 0):
         strided_lengths = lengths.view(len(keys), stride)
         if not torch.jit.is_scripting() and is_torchdynamo_compiling():
             torch._check(strided_lengths.size(0) > 0)
             torch._check(strided_lengths.size(1) > 0)
-        split_lengths = torch.unbind(
-            strided_lengths,
-            dim=0,
-        )
-        split_offsets = torch.unbind(
-            _batched_lengths_to_offsets(strided_lengths),
-            dim=0,
-        )
+
+        split_lengths = torch.unbind(strided_lengths, dim=0)
+        if compute_offsets:
+            split_offsets = torch.unbind(
+                _batched_lengths_to_offsets(strided_lengths), dim=0
+            )
     else:
         split_lengths = torch.unbind(lengths, dim=0)
-        split_offsets = torch.unbind(lengths, dim=0)
+        if compute_offsets:
+            split_offsets = split_lengths

     if weights is not None:
         weights_list = torch.split(weights, length_per_key)
         for idx, key in enumerate(keys):
-            length = split_lengths[idx]
-            offset = split_offsets[idx]
             _jt_dict[key] = JaggedTensor(
-                lengths=length,
-                offsets=offset,
+                lengths=split_lengths[idx],
+                # pyre-ignore[16]: Undefined attribute: None has no attribute __getitem__
+                offsets=split_offsets[idx] if split_offsets is not None else None,
                 values=values_list[idx],
                 weights=weights_list[idx],
             )
     else:
         for idx, key in enumerate(keys):
-            length = split_lengths[idx]
-            offset = split_offsets[idx]
             _jt_dict[key] = JaggedTensor(
-                lengths=length,
-                offsets=offset,
+                lengths=split_lengths[idx],
+                # pyre-ignore[16]: Undefined attribute: None has no attribute __getitem__
+                offsets=split_offsets[idx] if split_offsets is not None else None,
                 values=values_list[idx],
             )
+
     return _jt_dict

@@ -2698,11 +2700,14 @@ def __getitem__(self, key: str) -> JaggedTensor:
             offsets=None,
         )

-    def to_dict(self) -> Dict[str, JaggedTensor]:
+    def to_dict(self, compute_offsets: bool = True) -> Dict[str, JaggedTensor]:
         """
         Returns a dictionary of JaggedTensor for each key.
         Will cache result in self._jt_dict.

+        Args:
+            compute_offsets (bool): compute offsets for each JaggedTensor when True.
+
         Returns:
             Dict[str, JaggedTensor]: dictionary of JaggedTensor for each key.
         """
@@ -2720,6 +2725,7 @@ def to_dict(self) -> Dict[str, JaggedTensor]:
             variable_stride_per_key=self.variable_stride_per_key(),
             weights=self.weights_or_none(),
             jt_dict=self._jt_dict,
+            compute_offsets=compute_offsets,
         )
         self._jt_dict = _jt_dict
         return _jt_dict
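
The work being skipped is small but per-key: offsets are the complete cumulative sum of each key's lengths (a leading zero followed by the running total), computed with one fbgemm.asynchronous_complete_cumsum launch per key in the variable-stride path, or one batched _batched_lengths_to_offsets call otherwise. A rough pure-PyTorch equivalent of that cumsum, as a sketch of what compute_offsets=False avoids (not the actual fbgemm kernel):

import torch

def complete_cumsum(lengths: torch.Tensor) -> torch.Tensor:
    # offsets[i]:offsets[i + 1] then indexes row i's slice of `values`.
    zero = torch.zeros(1, dtype=lengths.dtype, device=lengths.device)
    return torch.cat([zero, torch.cumsum(lengths, dim=0).to(lengths.dtype)])

print(complete_cumsum(torch.IntTensor([2, 0, 1])))  # tensor([0, 2, 2, 3], dtype=torch.int32)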

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 48 additions & 0 deletions
@@ -811,6 +811,54 @@ def test_from_jt_dict_vb(self) -> None:
             torch.equal(j1.values(), torch.Tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))
         )

+    def test_to_dict_compute_offsets_false(self) -> None:
+        # Setup: KJT with two keys and standard stride
+        values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
+        keys = ["f1", "f2"]
+        lengths = torch.IntTensor([2, 0, 1, 1, 1, 3])
+
+        kjt = KeyedJaggedTensor(values=values, keys=keys, lengths=lengths)
+
+        # Execute: call to_dict with compute_offsets=False
+        jt_dict = kjt.to_dict(compute_offsets=False)
+
+        # Assert: offsets_or_none() should be None for each JaggedTensor
+        self.assertIsNone(jt_dict["f1"].offsets_or_none())
+        self.assertIsNone(jt_dict["f2"].offsets_or_none())
+        # Lengths should still be available
+        self.assertTrue(
+            torch.equal(jt_dict["f1"].lengths(), torch.IntTensor([2, 0, 1]))
+        )
+        self.assertTrue(
+            torch.equal(jt_dict["f2"].lengths(), torch.IntTensor([1, 1, 3]))
+        )
+
+    def test_to_dict_compute_offsets_false_variable_stride(self) -> None:
+        # Setup: KJT with variable stride per key (reusing test_from_jt_dict_vb data)
+        values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
+        keys = ["f1", "f2"]
+        lengths = torch.IntTensor([2, 0, 1, 1, 1, 3])
+        stride_per_key_per_rank = [[2], [4]]
+
+        kjt = KeyedJaggedTensor(
+            values=values,
+            keys=keys,
+            lengths=lengths,
+            stride_per_key_per_rank=stride_per_key_per_rank,
+        )
+
+        # Execute: call to_dict with compute_offsets=False
+        jt_dict = kjt.to_dict(compute_offsets=False)
+
+        # Assert: offsets_or_none() should be None for each JaggedTensor
+        self.assertIsNone(jt_dict["f1"].offsets_or_none())
+        self.assertIsNone(jt_dict["f2"].offsets_or_none())
+        # Lengths should still be available
+        self.assertTrue(torch.equal(jt_dict["f1"].lengths(), torch.IntTensor([2, 0])))
+        self.assertTrue(
+            torch.equal(jt_dict["f2"].lengths(), torch.IntTensor([1, 1, 1, 3]))
+        )
+

 class TestJaggedTensorTracing(unittest.TestCase):
     def test_jagged_tensor(self) -> None:
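
Since downstream code may now receive JaggedTensors without cached offsets, consumers that only sometimes need offsets can rebuild them from lengths. A defensive sketch (get_offsets is a hypothetical helper, not part of the TorchRec API; offsets_or_none and lengths are the real accessors exercised in the tests above):

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

def get_offsets(jt: JaggedTensor) -> torch.Tensor:
    # Prefer cached offsets; otherwise rebuild the complete cumsum of lengths.
    offsets = jt.offsets_or_none()
    if offsets is not None:
        return offsets
    lengths = jt.lengths()
    zero = torch.zeros(1, dtype=lengths.dtype, device=lengths.device)
    return torch.cat([zero, torch.cumsum(lengths, dim=0).to(lengths.dtype)])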
