Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion torchrec/distributed/embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ def create_virtual_table_global_metadata(
# The param size only has the information for my_rank. In order to
# correctly calculate the size for other ranks, we need to use the current
# rank's shard size compared to the shard size of my_rank.
curr_rank_rows = (param.size()[0] * metadata.shards_metadata[rank].shard_sizes[0]) // my_rank_shard_size # pyre-ignore[16]
curr_rank_rows = (
param.size()[0] # pyre-ignore[16]
* metadata.shards_metadata[rank].shard_sizes[0]
) // my_rank_shard_size
else:
curr_rank_rows = (
weight_count_per_rank[rank] if weight_count_per_rank is not None else 1
Expand Down
282 changes: 205 additions & 77 deletions torchrec/distributed/planner/planners.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@
from torchrec.distributed.planner.types import (
Enumerator,
hash_planner_context_inputs,
hash_planner_context_inputs_str,
ParameterConstraints,
Partitioner,
PerfModel,
PlanDebugStats,
PlanLoader,
PlannerError,
PlannerErrorType,
Proposer,
Expand Down Expand Up @@ -118,6 +120,60 @@ def to_sharding_plan(
return ShardingPlan(plan)


def extract_plan(
    search_space: List[ShardingOption],
    loaded_sharding_options: Dict[int, ShardingOption],
) -> List[ShardingOption]:
    """
    Merges loaded sharding options into the enumerated search space.

    Every option in ``search_space`` is looked up in ``loaded_sharding_options``
    by its ``storage_hash()``. For each match, a new ``ShardingOption`` is built
    that copies all fields from the enumerated option except ``shards``, which
    are taken from the loaded option. Enumerated options without a loaded
    counterpart are dropped.

    Args:
        search_space (List[ShardingOption]): enumerated sharding options.
        loaded_sharding_options (Dict[int, ShardingOption]): loaded sharding
            options, keyed by the value that ``storage_hash()`` is compared to.

    Returns:
        List[ShardingOption]: the merged options, in ``search_space`` order.

    Raises:
        PlannerError: with ``PLAN_LOADING_FAILED`` if two options in
            ``search_space`` share a storage hash, or if the number of merged
            options differs from the number of loaded options.
    """
    new_search_space: List[ShardingOption] = []
    # Maps each storage hash to the FQN of the first option that produced it,
    # so a collision can be reported with the two offending FQNs.
    fqn_by_storage_hash: Dict[int, str] = {}

    for so in search_space:
        # Computed once per option; it is used for the uniqueness check, the
        # lookup in the loaded plan, and the error message.
        storage_hash = so.storage_hash()

        # Validate that the storage hash is unique and isn't mapped to multiple sharding options
        if storage_hash in fqn_by_storage_hash:
            colliding_fqns = [fqn_by_storage_hash[storage_hash], so.fqn]
            raise PlannerError(
                error_type=PlannerErrorType.PLAN_LOADING_FAILED,
                message=f"Found a duplicate storage hash {storage_hash} for FQNs {colliding_fqns}\n",
            )
        fqn_by_storage_hash[storage_hash] = so.fqn

        loaded_so = loaded_sharding_options.get(storage_hash)
        if loaded_so is None:
            continue

        new_search_space.append(
            ShardingOption(
                name=so.name,
                tensor=so.tensor,
                module=so.module,
                input_lengths=so.input_lengths,
                batch_size=so.batch_size,
                compute_kernel=so.compute_kernel,
                sharding_type=so.sharding_type,
                partition_by=so.partition_by,
                # We only need to update the shards from the loaded plan
                shards=loaded_so.shards,
                cache_params=so.cache_params,
                enforce_hbm=so.enforce_hbm,
                stochastic_rounding=so.stochastic_rounding,
                bounds_check_mode=so.bounds_check_mode,
                dependency=so.dependency,
                is_pooled=so.is_pooled,
                feature_names=so.feature_names,
                output_dtype=so.output_dtype,
                key_value_params=so.key_value_params,
            )
        )

    # Validate that every loaded sharding option was matched by exactly one
    # enumerated option; a size mismatch means part of the loaded plan has no
    # counterpart in the current search space.
    if len(loaded_sharding_options) != len(new_search_space):
        raise PlannerError(
            error_type=PlannerErrorType.PLAN_LOADING_FAILED,
            message=f"Loaded sharding options from Storage, but not all search space is covered. Merged search space len {len(new_search_space)} != loaded Sharding options len {len(loaded_sharding_options)}\n",
        )
    return new_search_space


def _merge_plans(best_plans: List[ShardingPlan]) -> ShardingPlan:
if len(best_plans) == 1:
return best_plans[0]
Expand Down Expand Up @@ -269,6 +325,22 @@ def hash_planner_context_inputs(self) -> int:
self._constraints,
)

def hash_planner_context_inputs_str(self) -> str:
    """
    Returns the string hash of this planner's context inputs.

    The hash is computed from the topology, batch size, enumerator, storage
    reservation and constraints held by this planner. The partitioner,
    proposers, performance model and stats are not passed to the hash
    function. The result is meant to be compared against the context hash of
    a previously generated sharding plan to decide whether that plan still
    applies.

    Returns:
        str: hash of topology, batch size, enumerator, storage reservation
            and constraints.
    """
    topology = self._topology
    batch_size = self._batch_size
    enumerator = self._enumerator
    storage_reservation = self._storage_reservation
    constraints = self._constraints
    return hash_planner_context_inputs_str(
        topology,
        batch_size,
        enumerator,
        storage_reservation,
        constraints,
    )


class EmbeddingShardingPlanner(EmbeddingPlannerBase):
"""
Expand Down Expand Up @@ -315,6 +387,7 @@ def __init__(
List[Callable[[List[ShardingOption]], List[ShardingOption]]]
] = None,
timeout_seconds: Optional[int] = None,
plan_loader: Optional[PlanLoader] = None,
) -> None:
super().__init__(
topology=topology,
Expand Down Expand Up @@ -347,6 +420,8 @@ def __init__(
else NoopPerfModel(topology=self._topology)
)

self.plan_loader = plan_loader

self._num_proposals: int = 0
self._num_plans: int = 0
self._best_plan: Optional[List[ShardingOption]] = None
Expand Down Expand Up @@ -427,86 +502,113 @@ def plan(
# No shardable parameters
return ShardingPlan({})

proposal_cache: Dict[
Tuple[int, ...],
Tuple[bool, Optional[List[ShardingOption]], Optional[float]],
] = {}

for proposer in self._proposers:
proposer.load(search_space=search_space, enumerator=self._enumerator)

start = time.time()
for proposer in self._proposers:
proposal = proposer.propose()

while proposal:
end = time.time()
elapsed = end - start
if self._timeout_seconds:
if elapsed > self._timeout_seconds:
logger.info(
f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s"
)
break
proposal_key = tuple(sorted(map(hash, proposal)))
if proposal_key in proposal_cache:
partitionable, plan, perf_rating = proposal_cache[proposal_key]
proposer.feedback(
partitionable=partitionable,
plan=plan,
perf_rating=perf_rating,
storage_constraint=storage_constraint,
)
proposal = proposer.propose()
continue

self._num_proposals += 1
try:
# plan is just proposal where shard.rank is populated
plan = self._partitioner.partition(
proposal=proposal,
storage_constraint=storage_constraint,
)
self._num_plans += 1
perf_rating = self._perf_model.rate(plan=plan)
if perf_rating < best_perf_rating:
best_perf_rating = perf_rating
best_plan = copy.deepcopy(plan)
proposal_cache[proposal_key] = (True, plan, perf_rating)
proposer.feedback(
partitionable=True,
plan=plan,
perf_rating=perf_rating,
storage_constraint=storage_constraint,
)
except PlannerError as planner_error:
last_planner_error = planner_error
# shallow copy of the proposal
last_proposal: List[ShardingOption] = copy.copy(proposal)
current_storage = cast(
Storage,
reduce(
lambda x, y: x + y,
[
shard.storage
for option in proposal
for shard in option.shards
],
),
)
if current_storage < lowest_storage:
lowest_storage = current_storage
proposal_cache[proposal_key] = (False, proposal, None)
proposer.feedback(
partitionable=False,
plan=proposal,
storage_constraint=storage_constraint,
)
loaded_sharding_options = None
loaded_best_plan: List[ShardingOption] = []

if self.plan_loader is not None:
# validate plan before loading
self._loader_plan_validation(
current_planner_hash=self.hash_planner_context_inputs_str(),
# pyre-fixme[16]: `Optional` has no attribute `plan_context_hash`.
loaded_plan_hash=self.plan_loader.plan_context_hash(),
)
# pyre-ignore
loaded_sharding_options = self.plan_loader.load()
if loaded_sharding_options is not None:
# Merging sharding options from loaded plan with enumerated search space
loaded_best_plan = extract_plan(
search_space=search_space,
loaded_sharding_options=loaded_sharding_options,
)

# The loaded plan was validated successfully and can be used to generate the sharding plan, skipping new plan generation.
if loaded_best_plan:
logger.info(
# pyre-ignore
f"Loded sharding options from Storage with plan id: {self.plan_loader.get_plan_id()} skipping new plan generation"
)
best_plan = copy.deepcopy(loaded_best_plan)
else:
proposal_cache: Dict[
Tuple[int, ...],
Tuple[bool, Optional[List[ShardingOption]], Optional[float]],
] = {}

for proposer in self._proposers:
proposer.load(search_space=search_space, enumerator=self._enumerator)

# clear shard.rank for each sharding_option
reset_shard_rank(proposal)
start = time.time()
for proposer in self._proposers:
proposal = proposer.propose()

while proposal:
end = time.time()
elapsed = end - start
if self._timeout_seconds:
if elapsed > self._timeout_seconds:
logger.info(
f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s"
)
break
proposal_key = tuple(sorted(map(hash, proposal)))
if proposal_key in proposal_cache:
partitionable, plan, perf_rating = proposal_cache[proposal_key]
proposer.feedback(
partitionable=partitionable,
plan=plan,
perf_rating=perf_rating,
storage_constraint=storage_constraint,
)
proposal = proposer.propose()
continue

self._num_proposals += 1
try:
# plan is just proposal where shard.rank is populated
plan = self._partitioner.partition(
proposal=proposal,
storage_constraint=storage_constraint,
)
self._num_plans += 1
perf_rating = self._perf_model.rate(plan=plan)
if perf_rating < best_perf_rating:
best_perf_rating = perf_rating
best_plan = copy.deepcopy(plan)
proposal_cache[proposal_key] = (True, plan, perf_rating)
proposer.feedback(
partitionable=True,
plan=plan,
perf_rating=perf_rating,
storage_constraint=storage_constraint,
)
except PlannerError as planner_error:
last_planner_error = planner_error
# shallow copy of the proposal
last_proposal: List[ShardingOption] = copy.copy(proposal)
current_storage = cast(
Storage,
reduce(
lambda x, y: x + y,
[
shard.storage
for option in proposal
for shard in option.shards
],
),
)
if current_storage < lowest_storage:
lowest_storage = current_storage
proposal_cache[proposal_key] = (False, proposal, None)
proposer.feedback(
partitionable=False,
plan=proposal,
storage_constraint=storage_constraint,
)

# clear shard.rank for each sharding_option
reset_shard_rank(proposal)
proposal = proposer.propose()

if best_plan:
for callback in self._callbacks:
best_plan = callback(best_plan)
Expand Down Expand Up @@ -607,6 +709,32 @@ def plan(
+ last_planner_error_info,
)

def _loader_plan_validation(
    self, current_planner_hash: str, loaded_plan_hash: Optional[str]
) -> None:
    """
    Checks that a loaded plan was produced for the current planner context.

    A loaded hash of ``None`` is accepted without comparison; otherwise the
    two hashes must be equal.

    Args:
        current_planner_hash (str): context hash of the current planner.
        loaded_plan_hash (Optional[str]): context hash stored with the loaded
            plan, or ``None`` if there is none to compare.

    Raises:
        PlannerError: with ``PLANNER_INPUT_CONTEXT_MISMATCH`` if a loaded hash
            is present and differs from the current planner hash.
    """
    # Nothing to validate, or the contexts agree: the loaded plan is usable.
    if loaded_plan_hash is None or current_planner_hash == loaded_plan_hash:
        return

    # pyre-fixme[16]: `Optional` has no attribute `get_plan_id`.
    plan_id = self.plan_loader.get_plan_id() if self.plan_loader else None
    mismatch_details = (
        f"Planner input context mismatch detected for {plan_id} and current planner set up:"
        f"\nCurrent planner hash: {current_planner_hash}, Loaded plan hash: {loaded_plan_hash}"
    )
    headline = "Unable to load, because of planner input mismatch - cannot validate this plan is the best plan for current context.. \n"
    raise PlannerError(
        error_type=PlannerErrorType.PLANNER_INPUT_CONTEXT_MISMATCH,
        message=headline + mismatch_details,
    )


class HeteroEmbeddingShardingPlanner(ShardingPlanner):
"""
Expand Down
Loading
Loading