Commit 412ed27

coufon authored and facebook-github-bot committed
Support skipping offsets computation in KJT to_dict (#3503)
Summary: This diff optimizes KeyedJaggedTensor.to_dict by optionally skipping several unnecessary offset computations. Skipping them can measurably improve latency when offsets are not needed in the outputs.

Reviewed By: guowentian

Differential Revision: D85995629
1 parent ce16675 commit 412ed27
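A minimal usage sketch of the new flag (not part of the commit; the tensors are illustrative and it assumes a torchrec install that includes this change):

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two keys with stride 2: f1 covers lengths [1, 1], f2 covers lengths [1, 1].
kjt = KeyedJaggedTensor(
    values=torch.Tensor([1.0, 2.0, 3.0, 4.0]),
    keys=["f1", "f2"],
    lengths=torch.IntTensor([1, 1, 1, 1]),
)

# With compute_offsets=False, to_dict skips building per-key offsets;
# the returned JaggedTensors still carry lengths and values.
jt_dict = kjt.to_dict(compute_offsets=False)
assert jt_dict["f1"].offsets_or_none() is None
assert torch.equal(jt_dict["f1"].lengths(), torch.IntTensor([1, 1]))
assert torch.equal(jt_dict["f1"].values(), torch.Tensor([1.0, 2.0]))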

File tree

2 files changed: +74 -22 lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 26 additions & 22 deletions
@@ -1402,6 +1402,7 @@ def _maybe_compute_kjt_to_jt_dict(
     variable_stride_per_key: bool,
     weights: Optional[torch.Tensor],
     jt_dict: Optional[Dict[str, JaggedTensor]],
+    compute_offsets: bool = True,
 ) -> Dict[str, JaggedTensor]:
     if not length_per_key:
         return {}
@@ -1418,50 +1419,49 @@ def _maybe_compute_kjt_to_jt_dict(
             torch._check(cat_size <= total_size)
         torch._check(cat_size == total_size)
     torch._check_is_size(stride)
+
     values_list = torch.split(values, length_per_key)
+    split_offsets: list[torch.Tensor] = []
     if variable_stride_per_key:
         split_lengths = torch.split(lengths, stride_per_key)
-        split_offsets = [
-            torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-            for lengths in split_lengths
-        ]
+        if compute_offsets:
+            split_offsets = [
+                torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+                for lengths in split_lengths
+            ]
     elif pt2_guard_size_oblivious(lengths.numel() > 0):
         strided_lengths = lengths.view(len(keys), stride)
         if not torch.jit.is_scripting() and is_torchdynamo_compiling():
             torch._check(strided_lengths.size(0) > 0)
             torch._check(strided_lengths.size(1) > 0)
-        split_lengths = torch.unbind(
-            strided_lengths,
-            dim=0,
-        )
-        split_offsets = torch.unbind(
-            _batched_lengths_to_offsets(strided_lengths),
-            dim=0,
-        )
+
+        split_lengths = torch.unbind(strided_lengths, dim=0)
+        if compute_offsets:
+            split_offsets = torch.unbind(  # pyre-ignore
+                _batched_lengths_to_offsets(strided_lengths), dim=0
+            )
     else:
         split_lengths = torch.unbind(lengths, dim=0)
-        split_offsets = torch.unbind(lengths, dim=0)
+        if compute_offsets:
+            split_offsets = split_lengths  # pyre-ignore

     if weights is not None:
         weights_list = torch.split(weights, length_per_key)
         for idx, key in enumerate(keys):
-            length = split_lengths[idx]
-            offset = split_offsets[idx]
             _jt_dict[key] = JaggedTensor(
-                lengths=length,
-                offsets=offset,
+                lengths=split_lengths[idx],
+                offsets=split_offsets[idx] if compute_offsets else None,
                 values=values_list[idx],
                 weights=weights_list[idx],
             )
     else:
         for idx, key in enumerate(keys):
-            length = split_lengths[idx]
-            offset = split_offsets[idx]
             _jt_dict[key] = JaggedTensor(
-                lengths=length,
-                offsets=offset,
+                lengths=split_lengths[idx],
+                offsets=split_offsets[idx] if compute_offsets else None,
                 values=values_list[idx],
             )
+
     return _jt_dict
@@ -2698,11 +2698,14 @@ def __getitem__(self, key: str) -> JaggedTensor:
             offsets=None,
         )

-    def to_dict(self) -> Dict[str, JaggedTensor]:
+    def to_dict(self, compute_offsets: bool = True) -> Dict[str, JaggedTensor]:
         """
         Returns a dictionary of JaggedTensor for each key.
         Will cache result in self._jt_dict.

+        Args:
+            compute_offsets (bool): compute offsets when True.
+
         Returns:
             Dict[str, JaggedTensor]: dictionary of JaggedTensor for each key.
         """
@@ -2720,6 +2723,7 @@ def to_dict(self) -> Dict[str, JaggedTensor]:
             variable_stride_per_key=self.variable_stride_per_key(),
             weights=self.weights_or_none(),
             jt_dict=self._jt_dict,
+            compute_offsets=compute_offsets,
         )
         self._jt_dict = _jt_dict
         return _jt_dict
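For intuition about what compute_offsets=False avoids: per-key offsets are just a complete cumulative sum over that key's lengths, which the old code always materialized (via torch.ops.fbgemm.asynchronous_complete_cumsum or unbind, as shown above). A small framework-only sketch of that lengths-to-offsets relationship, with torch.cumsum standing in for the fbgemm op:

import torch

# Per-row lengths for a single key (same data as f1 in the new test below).
lengths = torch.tensor([2, 0, 1])
# A "complete" cumsum prepends a leading zero: one offset per row plus the total.
offsets = torch.zeros(len(lengths) + 1, dtype=lengths.dtype)
offsets[1:] = torch.cumsum(lengths, dim=0)
print(offsets)  # tensor([0, 2, 2, 3])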

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 48 additions & 0 deletions
@@ -811,6 +811,54 @@ def test_from_jt_dict_vb(self) -> None:
             torch.equal(j1.values(), torch.Tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))
         )

+    def test_to_dict_compute_offsets_false(self) -> None:
+        # Setup: KJT with two keys and standard stride
+        values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
+        keys = ["f1", "f2"]
+        lengths = torch.IntTensor([2, 0, 1, 1, 1, 3])
+
+        kjt = KeyedJaggedTensor(values=values, keys=keys, lengths=lengths)
+
+        # Execute: call to_dict with compute_offsets=False
+        jt_dict = kjt.to_dict(compute_offsets=False)
+
+        # Assert: offsets_or_none() should be None for each JaggedTensor
+        self.assertIsNone(jt_dict["f1"].offsets_or_none())
+        self.assertIsNone(jt_dict["f2"].offsets_or_none())
+        # Lengths should still be available
+        self.assertTrue(
+            torch.equal(jt_dict["f1"].lengths(), torch.IntTensor([2, 0, 1]))
+        )
+        self.assertTrue(
+            torch.equal(jt_dict["f2"].lengths(), torch.IntTensor([1, 1, 3]))
+        )
+
+    def test_to_dict_compute_offsets_false_variable_stride(self) -> None:
+        # Setup: KJT with variable stride per key (reusing test_from_jt_dict_vb data)
+        values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
+        keys = ["f1", "f2"]
+        lengths = torch.IntTensor([2, 0, 1, 1, 1, 3])
+        stride_per_key_per_rank = [[2], [4]]
+
+        kjt = KeyedJaggedTensor(
+            values=values,
+            keys=keys,
+            lengths=lengths,
+            stride_per_key_per_rank=stride_per_key_per_rank,
+        )
+
+        # Execute: call to_dict with compute_offsets=False
+        jt_dict = kjt.to_dict(compute_offsets=False)
+
+        # Assert: offsets_or_none() should be None for each JaggedTensor
+        self.assertIsNone(jt_dict["f1"].offsets_or_none())
+        self.assertIsNone(jt_dict["f2"].offsets_or_none())
+        # Lengths should still be available
+        self.assertTrue(torch.equal(jt_dict["f1"].lengths(), torch.IntTensor([2, 0])))
+        self.assertTrue(
+            torch.equal(jt_dict["f2"].lengths(), torch.IntTensor([1, 1, 1, 3]))
+        )
+

 class TestJaggedTensorTracing(unittest.TestCase):
     def test_jagged_tensor(self) -> None:
