Skip to content

Commit 4ed9d8a

Browse files
aliafzal authored and facebook-github-bot committed
Add plan paste for easily accessing the sharding plan (#3389)
Summary: internal. This diff introduces additional logging of the complete sharding plan in a human-readable format, making it easily accessible for review and analysis through the planner DB dataset. Differential Revision: D82945862
1 parent f48f626 commit 4ed9d8a

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

torchrec/distributed/planner/types.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -669,15 +669,23 @@ def __deepcopy__(
669669
return result
670670

671671
def __str__(self) -> str:
672-
str_obj: str = ""
673-
str_obj += f"name: {self.name}"
674-
str_obj += f"\nsharding type: {self.sharding_type}"
675-
str_obj += f"\ncompute kernel: {self.compute_kernel}"
676-
str_obj += f"\nnum shards: {len(self.shards)}"
677-
for shard in self.shards:
678-
str_obj += f"\n\t{str(shard)}"
679-
680-
return str_obj
672+
tensor_metadata = f"{{shape: {tuple(self.tensor.shape)}, dtype: {self.tensor.dtype}, device: {self.tensor.device}}}"
673+
shards_str = f"[{', '.join(str(shard) for shard in self.shards)}]"
674+
675+
return f"""{{
676+
"name": "{self.name}",
677+
"module_fqn": "{self.module[0]}",
678+
"tensor": {tensor_metadata},
679+
"input_lengths": {self.input_lengths},
680+
"batch_size": {self.batch_size},
681+
"sharding_type": "{self.sharding_type}",
682+
"compute_kernel": "{self.compute_kernel}",
683+
"shards": {shards_str},
684+
"is_pooled": {self.is_pooled},
685+
"feature_names": {self.feature_names},
686+
"cache_params": {self.cache_params},
687+
"is_weighted": {self.is_weighted}
688+
}}"""
681689

682690

683691
class PartitionByType(Enum):

0 commit comments

Comments
 (0)