@@ -1402,6 +1402,7 @@ def _maybe_compute_kjt_to_jt_dict(
14021402 variable_stride_per_key : bool ,
14031403 weights : Optional [torch .Tensor ],
14041404 jt_dict : Optional [Dict [str , JaggedTensor ]],
1405+ compute_offsets : bool = True ,
14051406) -> Dict [str , JaggedTensor ]:
14061407 if not length_per_key :
14071408 return {}
@@ -1418,50 +1419,51 @@ def _maybe_compute_kjt_to_jt_dict(
14181419 torch ._check (cat_size <= total_size )
14191420 torch ._check (cat_size == total_size )
14201421 torch ._check_is_size (stride )
1422+
14211423 values_list = torch .split (values , length_per_key )
1424+ split_offsets : tuple [torch .Tensor , ...] | None = None
14221425 if variable_stride_per_key :
14231426 split_lengths = torch .split (lengths , stride_per_key )
1424- split_offsets = [
1425- torch .ops .fbgemm .asynchronous_complete_cumsum (lengths )
1426- for lengths in split_lengths
1427- ]
1427+ if compute_offsets :
1428+ split_offsets = [
1429+ torch .ops .fbgemm .asynchronous_complete_cumsum (lengths )
1430+ for lengths in split_lengths
1431+ ]
14281432 elif pt2_guard_size_oblivious (lengths .numel () > 0 ):
14291433 strided_lengths = lengths .view (len (keys ), stride )
14301434 if not torch .jit .is_scripting () and is_torchdynamo_compiling ():
14311435 torch ._check (strided_lengths .size (0 ) > 0 )
14321436 torch ._check (strided_lengths .size (1 ) > 0 )
1433- split_lengths = torch .unbind (
1434- strided_lengths ,
1435- dim = 0 ,
1436- )
1437- split_offsets = torch .unbind (
1438- _batched_lengths_to_offsets (strided_lengths ),
1439- dim = 0 ,
1440- )
1437+
1438+ split_lengths = torch .unbind (strided_lengths , dim = 0 )
1439+ if compute_offsets :
1440+ split_offsets = torch .unbind (
1441+ _batched_lengths_to_offsets (strided_lengths ), dim = 0
1442+ )
14411443 else :
14421444 split_lengths = torch .unbind (lengths , dim = 0 )
1443- split_offsets = torch .unbind (lengths , dim = 0 )
1445+ if compute_offsets :
1446+ split_offsets = split_lengths
14441447
14451448 if weights is not None :
14461449 weights_list = torch .split (weights , length_per_key )
14471450 for idx , key in enumerate (keys ):
1448- length = split_lengths [idx ]
1449- offset = split_offsets [idx ]
14501451 _jt_dict [key ] = JaggedTensor (
1451- lengths = length ,
1452- offsets = offset ,
1452+ lengths = split_lengths [idx ],
1453+ # pyre-ignore[16]: Undefined attribute: None has no attribute __getitem__
1454+ offsets = split_offsets [idx ] if split_offsets is not None else None ,
14531455 values = values_list [idx ],
14541456 weights = weights_list [idx ],
14551457 )
14561458 else :
14571459 for idx , key in enumerate (keys ):
1458- length = split_lengths [idx ]
1459- offset = split_offsets [idx ]
14601460 _jt_dict [key ] = JaggedTensor (
1461- lengths = length ,
1462- offsets = offset ,
1461+ lengths = split_lengths [idx ],
1462+ # pyre-ignore[16]: Undefined attribute: None has no attribute __getitem__
1463+ offsets = split_offsets [idx ] if split_offsets is not None else None ,
14631464 values = values_list [idx ],
14641465 )
1466+
14651467 return _jt_dict
14661468
14671469
@@ -2698,11 +2700,14 @@ def __getitem__(self, key: str) -> JaggedTensor:
26982700 offsets = None ,
26992701 )
27002702
2701- def to_dict (self ) -> Dict [str , JaggedTensor ]:
2703+ def to_dict (self , compute_offsets : bool = True ) -> Dict [str , JaggedTensor ]:
27022704 """
27032705 Returns a dictionary of JaggedTensor for each key.
27042706 Will cache result in self._jt_dict.
27052707
2708+ Args:
2709+ compute_offsets (bool): if True (default), compute per-key offsets for each JaggedTensor; if False, leave each JaggedTensor's offsets as None.
2710+
27062711 Returns:
27072712 Dict[str, JaggedTensor]: dictionary of JaggedTensor for each key.
27082713 """
@@ -2720,6 +2725,7 @@ def to_dict(self) -> Dict[str, JaggedTensor]:
27202725 variable_stride_per_key = self .variable_stride_per_key (),
27212726 weights = self .weights_or_none (),
27222727 jt_dict = self ._jt_dict ,
2728+ compute_offsets = compute_offsets ,
27232729 )
27242730 self ._jt_dict = _jt_dict
27252731 return _jt_dict
0 commit comments