
Commit 933ee24

aliafzal authored and facebook-github-bot committed
Integrating planner stats db with ConfigeratorStats (#3331)
Summary: internal

Context: A planner stats db is introduced in this diff to track metadata and perf metrics associated with a sharding plan.

This Diff:
1. Added methods to insert, select, and delete planner stats db rows.
2. Added UTs for the planner stats db.
3. Integrated the planner stats db with ConfigeratorStats.

Reviewed By: ge0405

Differential Revision: D81216987
1 parent de5a17b commit 933ee24
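
The planner stats db itself is internal and its schema is not part of this commit. Purely to illustrate the insert/select/delete row operations the summary describes, here is a minimal sqlite sketch; the table name, columns, and helper names are all hypothetical, and the row is keyed by a plan-context hash string such as the one produced by the new hash_planner_context_inputs_str shown below.

import sqlite3
from typing import Optional, Tuple

# Hypothetical sketch only: the real planner stats db is internal and not
# shown in this commit. This illustrates insert/select/delete of a stats
# row keyed by a plan-context hash.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE planner_stats ("
    "plan_hash TEXT PRIMARY KEY, planner_type TEXT, timeout_seconds INTEGER)"
)

def insert_row(plan_hash: str, planner_type: str, timeout_seconds: int) -> None:
    conn.execute(
        "INSERT OR REPLACE INTO planner_stats VALUES (?, ?, ?)",
        (plan_hash, planner_type, timeout_seconds),
    )

def select_row(plan_hash: str) -> Optional[Tuple[str, str, int]]:
    return conn.execute(
        "SELECT * FROM planner_stats WHERE plan_hash = ?", (plan_hash,)
    ).fetchone()

def delete_row(plan_hash: str) -> None:
    conn.execute("DELETE FROM planner_stats WHERE plan_hash = ?", (plan_hash,))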

File tree: 3 files changed, +71 −0 lines changed

torchrec/distributed/planner/planners.py

Lines changed: 5 additions & 0 deletions
@@ -42,6 +42,7 @@
     ParameterConstraints,
     Partitioner,
     PerfModel,
+    PlanDebugStats,
     PlannerError,
     PlannerErrorType,
     Proposer,
@@ -528,6 +529,10 @@ def plan(
                     enumerator=self._enumerator,
                     sharders=sharders,
                     debug=self._debug,
+                    debug_stats=PlanDebugStats(
+                        planner_type=self.__class__.__name__,
+                        timeout_seconds=self._timeout_seconds,
+                    ),
                 )
             return sharding_plan
         else:
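
The hunk above is the integration point: when a plan is produced, the planner now attaches its own class name and timeout to the stats it logs. A standalone sketch of that construction follows; the concrete values are illustrative, since in the diff they come from self.__class__.__name__ and self._timeout_seconds at runtime.

from torchrec.distributed.planner.types import PlanDebugStats

# Mirrors what plan() now passes to the stats logger: the planner's class
# name plus its configured timeout. Values below are illustrative only.
debug_stats = PlanDebugStats(
    planner_type="EmbeddingShardingPlanner",
    timeout_seconds=300,
)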

torchrec/distributed/planner/stats.py

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,7 @@
     Enumerator,
     ParameterConstraints,
     Perf,
+    PlanDebugStats,
     ShardingOption,
     Stats,
     Storage,
@@ -160,6 +161,7 @@ def log(
         sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
         enumerator: Optional[Enumerator] = None,
         debug: bool = True,
+        debug_stats: Optional[PlanDebugStats] = None,
     ) -> None:
         """
         Logs stats for a given sharding plan.
@@ -1138,5 +1140,6 @@ def log(
         sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
         enumerator: Optional[Enumerator] = None,
         debug: bool = True,
+        debug_stats: Optional[PlanDebugStats] = None,
     ) -> None:
         pass
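
ConfigeratorStats itself is internal, so as a hypothetical illustration of how a Stats sink can consume the new argument, here is a minimal subclass; only the parameter added in this diff is spelled out, with the rest of the log() signature absorbed by *args/**kwargs.

from typing import Any, Optional

from torchrec.distributed.planner.types import PlanDebugStats, Stats

class PrintDebugStats(Stats):
    # Hypothetical Stats subclass for illustration; not part of this commit.
    def log(
        self,
        *args: Any,
        debug_stats: Optional[PlanDebugStats] = None,
        **kwargs: Any,
    ) -> None:
        # The new field carries plan metadata that a stats sink can persist.
        if debug_stats is not None:
            print(
                f"planner_type={debug_stats.planner_type} "
                f"timeout_seconds={debug_stats.timeout_seconds}"
            )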

torchrec/distributed/planner/types.py

Lines changed: 63 additions & 0 deletions
@@ -960,6 +960,16 @@ def partition(
         ...


+@dataclass
+class PlanDebugStats:
+    """
+    Representation of debug stats associated with a sharding plan, used for logging.
+    """
+
+    planner_type: str
+    timeout_seconds: Optional[int]
+
+
 class Stats(abc.ABC):
     """
     Logs statistics related to the sharding plan.
@@ -980,6 +990,7 @@ def log(
         sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
         enumerator: Optional[Enumerator] = None,
         debug: bool = False,
+        debug_stats: Optional[PlanDebugStats] = None,
     ) -> None:
         """
         See class description
@@ -1007,6 +1018,16 @@ def hash_sha256_to_int(hashable_list: List[Any]) -> int:  # pyre-ignore
     return int(hash_digest, 16)


+def hash_sha256_str(hashable_list: List[Any]) -> str:  # pyre-ignore
+    """
+    Hashes the given data using SHA256 and returns the hash as a string.
+    """
+    serialized_list = str(hashable_list).encode("utf-8")
+    hash_object = hashlib.sha256(serialized_list)
+    hash_digest = hash_object.hexdigest()
+    return hash_digest
+
+
 def hash_planner_context_inputs(
     topology: Topology,
     batch_size: int,
@@ -1047,3 +1068,45 @@ def hash_planner_context_inputs(
         constraints.items() if constraints else None,
     ]
     return hash_function(hashable_list)
+
+
+def hash_planner_context_inputs_str(
+    topology: Topology,
+    batch_size: int,
+    enumerator: Enumerator,
+    storage_reservation: StorageReservation,
+    constraints: Optional[Dict[str, ParameterConstraints]],
+    # pyre-ignore
+    hash_function: Callable[[List[Any]], str] = hash_sha256_str,
+) -> str:
+    assert hasattr(
+        enumerator, "last_stored_search_space"
+    ), "This enumerator is not compatible with hashing"
+    assert (
+        enumerator.last_stored_search_space is not None  # pyre-ignore
+    ), "Unable to hash planner context without an enumerator that has a precomputed search space"
+    search_space = enumerator.last_stored_search_space
+    storage_reservation_policy = type(storage_reservation).__name__
+
+    assert (
+        storage_reservation._last_reserved_topology is not None  # pyre-ignore
+    ), "Unable to hash planner context without a storage reservation that has a precomputed topology"
+
+    hashable_list = [
+        topology,
+        batch_size,
+        [
+            [
+                shard_option.fqn,
+                shard_option.sharding_type,
+                shard_option.compute_kernel,
+                tuple(shard_option.shards),
+                shard_option.cache_params,
+            ]
+            for shard_option in search_space
+        ],
+        storage_reservation_policy,
+        storage_reservation._last_reserved_topology,
+        constraints.items() if constraints else None,
+    ]
+    return hash_function(hashable_list)
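
hash_sha256_str serializes its input list with str(), hashes the UTF-8 bytes, and returns the hex digest, giving a stable string key (where the existing hash_sha256_to_int returns the same digest converted to an integer). Below is a short self-contained check of that behavior; the inline helper simply mirrors the function added above, and the sample list is made up.

import hashlib
from typing import Any, List

def hash_sha256_str(hashable_list: List[Any]) -> str:
    # Same logic as the helper added in this diff: serialize via str(),
    # encode as UTF-8, and return the SHA256 hex digest.
    serialized_list = str(hashable_list).encode("utf-8")
    return hashlib.sha256(serialized_list).hexdigest()

# Equal inputs yield equal 64-character hex digests, so the digest can act
# as a stable string key for a planner stats db row.
a = hash_sha256_str(["topology", 512, ("table_0", "row_wise")])
b = hash_sha256_str(["topology", 512, ("table_0", "row_wise")])
assert a == b and len(a) == 64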
