Commit 8adb0e7

aliafzal authored and facebook-github-bot committed
PlanLoader addition into planner (meta-pytorch#3355)
Summary:
Pull Request resolved: meta-pytorch#3355

**Summary:**
* Added PlanLoader abstract base class to enable loading pre-computed sharding plans from stored locations within the planner.
* Supports two key scenarios:
  1. Reusing previously computed and stored sharding plans to avoid regeneration costs
  2. Using sharding plans from previous runs as starting points for iterative improvements
* Defines two abstract methods:
  * `load()`: Returns a dictionary mapping sharding option hashes to ShardingOption objects (or `None`)
  * `plan_context_hash()`: Returns a hash of the sharding plan's input context, for plan integrity checks
* Part of the broader effort to improve planner UX and reliability by enabling plan persistence and reuse across training runs

Differential Revision: D81571293
1 parent a1689cd commit 8adb0e7

1 file changed: +34 -0 lines changed

torchrec/distributed/planner/types.py

Lines changed: 34 additions & 0 deletions
@@ -995,6 +995,40 @@ def log(
         ...
 
 
+class PlanLoader(abc.ABC):
+    """
+    Retrieves a pre-computed sharding plan from its stored location. This is useful in two scenarios:
+        1. To utilize a specific sharding plan that was previously computed and stored, saving the cost of re-generating the plan
+        2. To use a sharding plan from previous runs as a starting point for the next run, allowing for improvement over time.
+    """
+
+    @abc.abstractmethod
+    def load(
+        self,
+    ) -> Optional[Dict[int, ShardingOption]]:
+        """
+        Load sharding plan from its stored location.
+
+        Returns:
+            Dict[int, ShardingOption]: loaded sharding plan. key is hash of sharding option to map to sharding option with enumerated sharding option.
+        """
+        ...
+
+    @abc.abstractmethod
+    def plan_context_hash(
+        self,
+    ) -> Optional[str]:
+        """
+        Input context hash of a sharding plan.
+
+        Returns:
+            str: hash of sharding plan context.
+        """
+        ...
+
+    ...
+
+
 @dataclass
 class CriticalPathEstimate:
     comms_estimate: float
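
For illustration only, here is a minimal sketch of how a concrete loader might implement the two abstract methods, and how a caller might decide whether a stored plan is still usable. None of this is part of the commit. `ShardingOption` below is a simplified stand-in for the torchrec class, `PlanLoader` is re-declared so the sketch runs without torchrec, and `InMemoryPlanLoader`, `context_hash`, `resolve_plan`, and the JSON layout are hypothetical names and formats chosen for the example. The key `101` is an arbitrary placeholder for a sharding option hash.

```python
import abc
import hashlib
import json
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass(frozen=True)
class ShardingOption:
    """Hypothetical stand-in for torchrec's ShardingOption (illustration only)."""

    name: str
    sharding_type: str


class PlanLoader(abc.ABC):
    """Local copy of the interface added in this commit, so the sketch is self-contained."""

    @abc.abstractmethod
    def load(self) -> Optional[Dict[int, ShardingOption]]: ...

    @abc.abstractmethod
    def plan_context_hash(self) -> Optional[str]: ...


def context_hash(context: Dict[str, object]) -> str:
    """Hypothetical helper: stable digest of the inputs a plan was computed for."""
    payload = json.dumps(context, sort_keys=True).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


class InMemoryPlanLoader(PlanLoader):
    """Hypothetical loader that reads a plan from a JSON string (assumed format)."""

    def __init__(self, serialized: Optional[str]) -> None:
        self._record = json.loads(serialized) if serialized else None

    def load(self) -> Optional[Dict[int, ShardingOption]]:
        if self._record is None:
            return None
        # JSON object keys are strings, so convert them back to integer hashes.
        return {
            int(key): ShardingOption(**value)
            for key, value in self._record["plan"].items()
        }

    def plan_context_hash(self) -> Optional[str]:
        if self._record is None:
            return None
        return self._record["context_hash"]


def resolve_plan(
    loader: PlanLoader, current_context: Dict[str, object]
) -> Optional[Dict[int, ShardingOption]]:
    """Return the stored plan only if it was computed for the same context."""
    if loader.plan_context_hash() != context_hash(current_context):
        return None  # missing or stale plan: the caller should re-plan
    return loader.load()


# Usage: store a plan for one context, then try to reuse it.
context = {"world_size": 8, "tables": ["t0", "t1"]}
stored = json.dumps(
    {
        "context_hash": context_hash(context),
        "plan": {"101": {"name": "t0", "sharding_type": "table_wise"}},
    }
)
loader = InMemoryPlanLoader(stored)
plan = resolve_plan(loader, context)
stale = resolve_plan(loader, {"world_size": 16, "tables": ["t0", "t1"]})
```

In this sketch a mismatched or missing context hash makes `resolve_plan` return `None`, which a caller could treat as a signal to generate a fresh plan. How the actual planner consumes `plan_context_hash()` is not shown in this commit, so the comparison above is an assumption.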
