@@ -1402,6 +1402,7 @@ def _maybe_compute_kjt_to_jt_dict(
     variable_stride_per_key: bool,
     weights: Optional[torch.Tensor],
     jt_dict: Optional[Dict[str, JaggedTensor]],
+    compute_offsets: bool = True,
 ) -> Dict[str, JaggedTensor]:
     if not length_per_key:
         return {}
@@ -1418,50 +1419,50 @@ def _maybe_compute_kjt_to_jt_dict(
             torch._check(cat_size <= total_size)
         torch._check(cat_size == total_size)
         torch._check_is_size(stride)
+
     values_list = torch.split(values, length_per_key)
+    split_offsets: list[torch.Tensor] | None = None
     if variable_stride_per_key:
         split_lengths = torch.split(lengths, stride_per_key)
-        split_offsets = [
-            torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-            for lengths in split_lengths
-        ]
+        if compute_offsets:
+            split_offsets = [
+                torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+                for lengths in split_lengths
+            ]
     elif pt2_guard_size_oblivious(lengths.numel() > 0):
         strided_lengths = lengths.view(len(keys), stride)
         if not torch.jit.is_scripting() and is_torchdynamo_compiling():
             torch._check(strided_lengths.size(0) > 0)
             torch._check(strided_lengths.size(1) > 0)
-        split_lengths = torch.unbind(
-            strided_lengths,
-            dim=0,
-        )
-        split_offsets = torch.unbind(
-            _batched_lengths_to_offsets(strided_lengths),
-            dim=0,
-        )
+
+        split_lengths = torch.unbind(strided_lengths, dim=0)
+        if compute_offsets:
+            split_offsets = torch.unbind(  # pyre-ignore
+                _batched_lengths_to_offsets(strided_lengths), dim=0
+            )
     else:
         split_lengths = torch.unbind(lengths, dim=0)
-        split_offsets = torch.unbind(lengths, dim=0)
+        if compute_offsets:
+            split_offsets = split_lengths  # pyre-ignore

     if weights is not None:
         weights_list = torch.split(weights, length_per_key)
         for idx, key in enumerate(keys):
-            length = split_lengths[idx]
-            offset = split_offsets[idx]
             _jt_dict[key] = JaggedTensor(
-                lengths=length,
-                offsets=offset,
+                lengths=split_lengths[idx],
+                # pyre-ignore[16]: Undefined attribute: None has no attribute __getitem__
+                offsets=split_offsets[idx] if split_offsets is not None else None,
                 values=values_list[idx],
                 weights=weights_list[idx],
             )
     else:
         for idx, key in enumerate(keys):
-            length = split_lengths[idx]
-            offset = split_offsets[idx]
             _jt_dict[key] = JaggedTensor(
-                lengths=length,
-                offsets=offset,
+                lengths=split_lengths[idx],
+                offsets=split_offsets[idx] if split_offsets is not None else None,
                 values=values_list[idx],
             )
+
     return _jt_dict

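For reference, a minimal standalone sketch of the lengths-to-offsets conversion that `compute_offsets=False` now skips. This uses plain torch rather than the fbgemm kernel; `complete_cumsum` is a hypothetical stand-in for `torch.ops.fbgemm.asynchronous_complete_cumsum`:

```python
import torch

def complete_cumsum(lengths: torch.Tensor) -> torch.Tensor:
    # Stand-in for fbgemm's asynchronous_complete_cumsum on a 1-D tensor:
    # prepend a zero so offsets[i] marks the start of row i and
    # offsets[-1] equals the total number of values.
    zero = torch.zeros(1, dtype=lengths.dtype, device=lengths.device)
    return torch.cat([zero, torch.cumsum(lengths, dim=0)])

lengths = torch.tensor([2, 0, 3])
print(complete_cumsum(lengths))  # tensor([0, 2, 2, 5])
```

Skipping this per-key cumsum is the whole point of the flag: when a caller only needs lengths and values, the offsets kernels never launch.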
@@ -2698,11 +2699,14 @@ def __getitem__(self, key: str) -> JaggedTensor:
             offsets=None,
         )

-    def to_dict(self) -> Dict[str, JaggedTensor]:
+    def to_dict(self, compute_offsets: bool = True) -> Dict[str, JaggedTensor]:
         """
         Returns a dictionary of JaggedTensor for each key.
         Will cache result in self._jt_dict.

+        Args:
+            compute_offsets (bool): if True, compute offsets for each JaggedTensor;
+                otherwise leave offsets unset (None).
+
         Returns:
             Dict[str, JaggedTensor]: dictionary of JaggedTensor for each key.
         """
@@ -2720,6 +2724,7 @@ def to_dict(self) -> Dict[str, JaggedTensor]:
             variable_stride_per_key=self.variable_stride_per_key(),
             weights=self.weights_or_none(),
             jt_dict=self._jt_dict,
+            compute_offsets=compute_offsets,
         )
         self._jt_dict = _jt_dict
         return _jt_dict
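A usage sketch of the new flag, assuming torchrec's public `KeyedJaggedTensor` API (`from_lengths_sync`, `lengths()`, `offsets_or_none()`); the tensors are illustrative:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two keys, stride 2: lengths are laid out per key, then per batch element.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]),
    lengths=torch.tensor([2, 0, 1, 2]),
)

# With compute_offsets=False, each JaggedTensor carries lengths and values
# only; no cumsum kernel runs, and offsets stay None.
jt_dict = kjt.to_dict(compute_offsets=False)
print(jt_dict["f1"].lengths())          # tensor([2, 0])
print(jt_dict["f1"].offsets_or_none())  # None
```

Note that `to_dict` caches its result in `self._jt_dict`, so the first call's `compute_offsets` value determines whether later calls see offsets.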