@@ -960,6 +960,16 @@ def partition(
960
960
...
961
961
962
962
963
+ @dataclass
964
+ class PlanDebugStats :
965
+ """
966
+ Representation of debug stats associated with a sharding plan, used for logging.
967
+ """
968
+
969
+ planner_type : str
970
+ timeout_seconds : Optional [int ]
971
+
972
+
963
973
class Stats (abc .ABC ):
964
974
"""
965
975
Logs statistics related to the sharding plan.
@@ -980,6 +990,7 @@ def log(
980
990
sharders : Optional [List [ModuleSharder [nn .Module ]]] = None ,
981
991
enumerator : Optional [Enumerator ] = None ,
982
992
debug : bool = False ,
993
+ debug_stats : Optional [PlanDebugStats ] = None ,
983
994
) -> None :
984
995
"""
985
996
See class description
@@ -1007,6 +1018,16 @@ def hash_sha256_to_int(hashable_list: List[Any]) -> int: # pyre-ignore
1007
1018
return int (hash_digest , 16 )
1008
1019
1009
1020
1021
+ def hash_sha256_str (hashable_list : List [Any ]) -> str : # pyre-ignore
1022
+ """
1023
+ Hashes the given data using SHA256 and returns the hash as an string
1024
+ """
1025
+ serialized_list = str (hashable_list ).encode ("utf-8" )
1026
+ hash_object = hashlib .sha256 (serialized_list )
1027
+ hash_digest = hash_object .hexdigest ()
1028
+ return hash_digest
1029
+
1030
+
1010
1031
def hash_planner_context_inputs (
1011
1032
topology : Topology ,
1012
1033
batch_size : int ,
@@ -1047,3 +1068,45 @@ def hash_planner_context_inputs(
1047
1068
constraints .items () if constraints else None ,
1048
1069
]
1049
1070
return hash_function (hashable_list )
1071
+
1072
+
1073
+ def hash_planner_context_inputs_str (
1074
+ topology : Topology ,
1075
+ batch_size : int ,
1076
+ enumerator : Enumerator ,
1077
+ storage_reservation : StorageReservation ,
1078
+ constraints : Optional [Dict [str , ParameterConstraints ]],
1079
+ # pyre-ignore
1080
+ hash_function : Callable [[List [Any ]], str ] = hash_sha256_str ,
1081
+ ) -> str :
1082
+ assert hasattr (
1083
+ enumerator , "last_stored_search_space"
1084
+ ), "This enumerator is not compatible with hashing"
1085
+ assert (
1086
+ enumerator .last_stored_search_space is not None # pyre-ignore
1087
+ ), "Unable to hash planner context without an enumerator that has a precomputed search space"
1088
+ search_space = enumerator .last_stored_search_space
1089
+ storage_reservation_policy = type (storage_reservation ).__name__
1090
+
1091
+ assert (
1092
+ storage_reservation ._last_reserved_topology is not None # pyre-ignore
1093
+ ), "Unable to hash planner context without a storage reservation that has a precomputed topology"
1094
+
1095
+ hashable_list = [
1096
+ topology ,
1097
+ batch_size ,
1098
+ [
1099
+ [
1100
+ shard_option .fqn ,
1101
+ shard_option .sharding_type ,
1102
+ shard_option .compute_kernel ,
1103
+ tuple (shard_option .shards ),
1104
+ shard_option .cache_params ,
1105
+ ]
1106
+ for shard_option in search_space
1107
+ ],
1108
+ storage_reservation_policy ,
1109
+ storage_reservation ._last_reserved_topology ,
1110
+ constraints .items () if constraints else None ,
1111
+ ]
1112
+ return hash_function (hashable_list )
0 commit comments