@@ -1402,6 +1402,7 @@ def _maybe_compute_kjt_to_jt_dict(
14021402 variable_stride_per_key : bool ,
14031403 weights : Optional [torch .Tensor ],
14041404 jt_dict : Optional [Dict [str , JaggedTensor ]],
1405+ compute_offsets : bool = True ,
14051406) -> Dict [str , JaggedTensor ]:
14061407 if not length_per_key :
14071408 return {}
@@ -1418,50 +1419,49 @@ def _maybe_compute_kjt_to_jt_dict(
14181419 torch ._check (cat_size <= total_size )
14191420 torch ._check (cat_size == total_size )
14201421 torch ._check_is_size (stride )
1422+
14211423 values_list = torch .split (values , length_per_key )
1424+ split_offsets : list [torch .Tensor ] = []
14221425 if variable_stride_per_key :
14231426 split_lengths = torch .split (lengths , stride_per_key )
1424- split_offsets = [
1425- torch .ops .fbgemm .asynchronous_complete_cumsum (lengths )
1426- for lengths in split_lengths
1427- ]
1427+ if compute_offsets :
1428+ split_offsets = [
1429+ torch .ops .fbgemm .asynchronous_complete_cumsum (lengths )
1430+ for lengths in split_lengths
1431+ ]
14281432 elif pt2_guard_size_oblivious (lengths .numel () > 0 ):
14291433 strided_lengths = lengths .view (len (keys ), stride )
14301434 if not torch .jit .is_scripting () and is_torchdynamo_compiling ():
14311435 torch ._check (strided_lengths .size (0 ) > 0 )
14321436 torch ._check (strided_lengths .size (1 ) > 0 )
1433- split_lengths = torch .unbind (
1434- strided_lengths ,
1435- dim = 0 ,
1436- )
1437- split_offsets = torch .unbind (
1438- _batched_lengths_to_offsets (strided_lengths ),
1439- dim = 0 ,
1440- )
1437+
1438+ split_lengths = torch .unbind (strided_lengths , dim = 0 )
1439+ if compute_offsets :
1440+ split_offsets = torch .unbind ( # pyre-ignore
1441+ _batched_lengths_to_offsets (strided_lengths ), dim = 0
1442+ )
14411443 else :
14421444 split_lengths = torch .unbind (lengths , dim = 0 )
1443- split_offsets = torch .unbind (lengths , dim = 0 )
1445+ if compute_offsets :
1446+ split_offsets = split_lengths # pyre-ignore
14441447
14451448 if weights is not None :
14461449 weights_list = torch .split (weights , length_per_key )
14471450 for idx , key in enumerate (keys ):
1448- length = split_lengths [idx ]
1449- offset = split_offsets [idx ]
14501451 _jt_dict [key ] = JaggedTensor (
1451- lengths = length ,
1452- offsets = offset ,
1452+ lengths = split_lengths [ idx ] ,
1453+ offsets = split_offsets [ idx ] if compute_offsets else None ,
14531454 values = values_list [idx ],
14541455 weights = weights_list [idx ],
14551456 )
14561457 else :
14571458 for idx , key in enumerate (keys ):
1458- length = split_lengths [idx ]
1459- offset = split_offsets [idx ]
14601459 _jt_dict [key ] = JaggedTensor (
1461- lengths = length ,
1462- offsets = offset ,
1460+ lengths = split_lengths [ idx ] ,
1461+ offsets = split_offsets [ idx ] if compute_offsets else None ,
14631462 values = values_list [idx ],
14641463 )
1464+
14651465 return _jt_dict
14661466
14671467
@@ -2698,11 +2698,14 @@ def __getitem__(self, key: str) -> JaggedTensor:
26982698 offsets = None ,
26992699 )
27002700
2701- def to_dict (self ) -> Dict [str , JaggedTensor ]:
2701+ def to_dict (self , compute_offsets : bool = True ) -> Dict [str , JaggedTensor ]:
27022702 """
27032703 Returns a dictionary of JaggedTensor for each key.
27042704 Will cache result in self._jt_dict.
27052705
2706+ Args:
2707+ compute_offsets (bool): compute offsets for each JaggedTensor when True. Defaults to True.
2708+
27062709 Returns:
27072710 Dict[str, JaggedTensor]: dictionary of JaggedTensor for each key.
27082711 """
@@ -2720,6 +2723,7 @@ def to_dict(self) -> Dict[str, JaggedTensor]:
27202723 variable_stride_per_key = self .variable_stride_per_key (),
27212724 weights = self .weights_or_none (),
27222725 jt_dict = self ._jt_dict ,
2726+ compute_offsets = compute_offsets ,
27232727 )
27242728 self ._jt_dict = _jt_dict
27252729 return _jt_dict
0 commit comments