diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py
index e444f59c8..09a2c0375 100644
--- a/torchrec/distributed/embedding_kernel.py
+++ b/torchrec/distributed/embedding_kernel.py
@@ -105,7 +105,10 @@ def create_virtual_table_global_metadata(
         # The param size only has the information for my_rank. In order to
         # correctly calculate the size for other ranks, we need to use the current
         # rank's shard size compared to the shard size of my_rank.
-        curr_rank_rows = (param.size()[0] * metadata.shards_metadata[rank].shard_sizes[0]) // my_rank_shard_size  # pyre-ignore[16]
+        curr_rank_rows = (
+            param.size()[0]  # pyre-ignore[16]
+            * metadata.shards_metadata[rank].shard_sizes[0]
+        ) // my_rank_shard_size
     else:
         curr_rank_rows = (
             weight_count_per_rank[rank] if weight_count_per_rank is not None else 1
diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py
index 6f0b3da2f..8b48865b6 100644
--- a/torchrec/distributed/planner/planners.py
+++ b/torchrec/distributed/planner/planners.py
@@ -39,10 +39,12 @@ from torchrec.distributed.planner.types import (
     Enumerator,
     hash_planner_context_inputs,
+    hash_planner_context_inputs_str,
     ParameterConstraints,
     Partitioner,
     PerfModel,
     PlanDebugStats,
+    PlanLoader,
     PlannerError,
     PlannerErrorType,
     Proposer,
@@ -118,6 +120,60 @@ def to_sharding_plan(
     return ShardingPlan(plan)
 
 
+def extract_plan(
+    search_space: List[ShardingOption],
+    loaded_sharding_options: Dict[int, ShardingOption],
+) -> List[ShardingOption]:
+    new_search_space: List[ShardingOption] = []
+    seen_hash_set = set()
+
+    for so in search_space:
+        # Validate that the storage hash is unique and isn't mapped to multiple sharding options
+        if so.storage_hash() in seen_hash_set:
+            raise PlannerError(
+                error_type=PlannerErrorType.PLAN_LOADING_FAILED,
+                message=f"Found a duplicate storage hash {so.storage_hash()} for FQNs {[option.fqn for option in search_space]}\n",
+            )
+        else:
+            seen_hash_set.add(so.storage_hash())
+
+        loaded_so = loaded_sharding_options.get(so.storage_hash())
+        if loaded_so is not None:
+            new_search_space.append(
+                ShardingOption(
+                    name=so.name,
+                    tensor=so.tensor,
+                    module=so.module,
+                    input_lengths=so.input_lengths,
+                    batch_size=so.batch_size,
+                    compute_kernel=so.compute_kernel,
+                    sharding_type=so.sharding_type,
+                    partition_by=so.partition_by,
+                    # We only need to update the shards from the loaded plan
+                    shards=loaded_so.shards,
+                    cache_params=so.cache_params,
+                    enforce_hbm=so.enforce_hbm,
+                    stochastic_rounding=so.stochastic_rounding,
+                    bounds_check_mode=so.bounds_check_mode,
+                    dependency=so.dependency,
+                    is_pooled=so.is_pooled,
+                    feature_names=so.feature_names,
+                    output_dtype=so.output_dtype,
+                    key_value_params=so.key_value_params,
+                )
+            )
+
+    # Validate that the populated search space is the same size as the enumerated search space
+    if len(loaded_sharding_options) != len(new_search_space):
+        raise PlannerError(
+            error_type=PlannerErrorType.PLAN_LOADING_FAILED,
+            message=f"Loaded sharding options from Storage, but not all search space is covered. "
+            f"Merged search space len {len(new_search_space)} != loaded sharding options len {len(loaded_sharding_options)}\n",
+        )
+    return new_search_space
+
+
 def _merge_plans(best_plans: List[ShardingPlan]) -> ShardingPlan:
     if len(best_plans) == 1:
         return best_plans[0]
@@ -269,6 +325,22 @@ def hash_planner_context_inputs(self) -> int:
             self._constraints,
         )
 
+    def hash_planner_context_inputs_str(self) -> str:
+        """
+        Generates a hash of all planner inputs except the partitioner, proposer, performance model, and stats.
+        These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context.
+
+        Returns:
+            A string hash capturing topology, batch size, enumerator, storage reservation, and constraints.
+        """
+        return hash_planner_context_inputs_str(
+            self._topology,
+            self._batch_size,
+            self._enumerator,
+            self._storage_reservation,
+            self._constraints,
+        )
+
 
 class EmbeddingShardingPlanner(EmbeddingPlannerBase):
     """
@@ -315,6 +387,7 @@ def __init__(
             List[Callable[[List[ShardingOption]], List[ShardingOption]]]
         ] = None,
         timeout_seconds: Optional[int] = None,
+        plan_loader: Optional[PlanLoader] = None,
     ) -> None:
         super().__init__(
             topology=topology,
@@ -347,6 +420,8 @@ def __init__(
             else NoopPerfModel(topology=self._topology)
         )
 
+        self.plan_loader = plan_loader
+
         self._num_proposals: int = 0
         self._num_plans: int = 0
         self._best_plan: Optional[List[ShardingOption]] = None
@@ -427,86 +502,113 @@ def plan(
             # No shardable parameters
             return ShardingPlan({})
 
-        proposal_cache: Dict[
-            Tuple[int, ...],
-            Tuple[bool, Optional[List[ShardingOption]], Optional[float]],
-        ] = {}
-
-        for proposer in self._proposers:
-            proposer.load(search_space=search_space, enumerator=self._enumerator)
-
-        start = time.time()
-        for proposer in self._proposers:
-            proposal = proposer.propose()
-
-            while proposal:
-                end = time.time()
-                elapsed = end - start
-                if self._timeout_seconds:
-                    if elapsed > self._timeout_seconds:
-                        logger.info(
-                            f"Exceeded time limit of {self._timeout_seconds}s. "
Took {elapsed}s" - ) - break - proposal_key = tuple(sorted(map(hash, proposal))) - if proposal_key in proposal_cache: - partitionable, plan, perf_rating = proposal_cache[proposal_key] - proposer.feedback( - partitionable=partitionable, - plan=plan, - perf_rating=perf_rating, - storage_constraint=storage_constraint, - ) - proposal = proposer.propose() - continue - - self._num_proposals += 1 - try: - # plan is just proposal where shard.rank is populated - plan = self._partitioner.partition( - proposal=proposal, - storage_constraint=storage_constraint, - ) - self._num_plans += 1 - perf_rating = self._perf_model.rate(plan=plan) - if perf_rating < best_perf_rating: - best_perf_rating = perf_rating - best_plan = copy.deepcopy(plan) - proposal_cache[proposal_key] = (True, plan, perf_rating) - proposer.feedback( - partitionable=True, - plan=plan, - perf_rating=perf_rating, - storage_constraint=storage_constraint, - ) - except PlannerError as planner_error: - last_planner_error = planner_error - # shallow copy of the proposal - last_proposal: List[ShardingOption] = copy.copy(proposal) - current_storage = cast( - Storage, - reduce( - lambda x, y: x + y, - [ - shard.storage - for option in proposal - for shard in option.shards - ], - ), - ) - if current_storage < lowest_storage: - lowest_storage = current_storage - proposal_cache[proposal_key] = (False, proposal, None) - proposer.feedback( - partitionable=False, - plan=proposal, - storage_constraint=storage_constraint, - ) + loaded_sharding_options = None + loaded_best_plan: List[ShardingOption] = [] + + if self.plan_loader is not None: + # validate plan before loading + self._loader_plan_validation( + current_planner_hash=self.hash_planner_context_inputs_str(), + # pyre-fixme[16]: `Optional` has no attribute `plan_context_hash`. + loaded_plan_hash=self.plan_loader.plan_context_hash(), + ) + # pyre-ignore + loaded_sharding_options = self.plan_loader.load() + if loaded_sharding_options is not None: + # Merging sharding options from loaded plan with enumerated search space + loaded_best_plan = extract_plan( + search_space=search_space, + loaded_sharding_options=loaded_sharding_options, + ) + + # Loaded plan is validated successfully and can be used for generate the sharding plan, skipping new plan generation. + if loaded_best_plan: + logger.info( + # pyre-ignore + f"Loded sharding options from Storage with plan id: {self.plan_loader.get_plan_id()} skipping new plan generation" + ) + best_plan = copy.deepcopy(loaded_best_plan) + else: + proposal_cache: Dict[ + Tuple[int, ...], + Tuple[bool, Optional[List[ShardingOption]], Optional[float]], + ] = {} + + for proposer in self._proposers: + proposer.load(search_space=search_space, enumerator=self._enumerator) - # clear shard.rank for each sharding_option - reset_shard_rank(proposal) + start = time.time() + for proposer in self._proposers: proposal = proposer.propose() + while proposal: + end = time.time() + elapsed = end - start + if self._timeout_seconds: + if elapsed > self._timeout_seconds: + logger.info( + f"Exceeded time limit of {self._timeout_seconds}s. 
Took {elapsed}s" + ) + break + proposal_key = tuple(sorted(map(hash, proposal))) + if proposal_key in proposal_cache: + partitionable, plan, perf_rating = proposal_cache[proposal_key] + proposer.feedback( + partitionable=partitionable, + plan=plan, + perf_rating=perf_rating, + storage_constraint=storage_constraint, + ) + proposal = proposer.propose() + continue + + self._num_proposals += 1 + try: + # plan is just proposal where shard.rank is populated + plan = self._partitioner.partition( + proposal=proposal, + storage_constraint=storage_constraint, + ) + self._num_plans += 1 + perf_rating = self._perf_model.rate(plan=plan) + if perf_rating < best_perf_rating: + best_perf_rating = perf_rating + best_plan = copy.deepcopy(plan) + proposal_cache[proposal_key] = (True, plan, perf_rating) + proposer.feedback( + partitionable=True, + plan=plan, + perf_rating=perf_rating, + storage_constraint=storage_constraint, + ) + except PlannerError as planner_error: + last_planner_error = planner_error + # shallow copy of the proposal + last_proposal: List[ShardingOption] = copy.copy(proposal) + current_storage = cast( + Storage, + reduce( + lambda x, y: x + y, + [ + shard.storage + for option in proposal + for shard in option.shards + ], + ), + ) + if current_storage < lowest_storage: + lowest_storage = current_storage + proposal_cache[proposal_key] = (False, proposal, None) + proposer.feedback( + partitionable=False, + plan=proposal, + storage_constraint=storage_constraint, + ) + + # clear shard.rank for each sharding_option + reset_shard_rank(proposal) + proposal = proposer.propose() + if best_plan: for callback in self._callbacks: best_plan = callback(best_plan) @@ -607,6 +709,32 @@ def plan( + last_planner_error_info, ) + def _loader_plan_validation( + self, current_planner_hash: str, loaded_plan_hash: Optional[str] + ) -> None: + """ + Validates that the current planner context hash matches the loaded plan context hash. + + Args: + current_planner_hash (str): Hash from current planner context + loaded_plan_hash (Optional[str]): Hash from loaded plan context + + Raises: + PlannerError: If hashes don't match + """ + if loaded_plan_hash is not None and current_planner_hash != loaded_plan_hash: + # pyre-fixme[16]: `Optional` has no attribute `get_plan_id`. + plan_id = self.plan_loader.get_plan_id() if self.plan_loader else None + error_msg = ( + f"Planner input context mismatch detected for {plan_id} and current planner set up:" + f"\nCurrent planner hash: {current_planner_hash}, Loaded plan hash: {loaded_plan_hash}" + ) + raise PlannerError( + error_type=PlannerErrorType.PLANNER_INPUT_CONTEXT_MISMATCH, + message="Unable to load, because of planner input mismatch - cannot validate this plan is the best plan for current context.. 
\n" + + error_msg, + ) + class HeteroEmbeddingShardingPlanner(ShardingPlanner): """ diff --git a/torchrec/distributed/planner/tests/test_planners.py b/torchrec/distributed/planner/tests/test_planners.py index 64f96c4d0..44cae7dbe 100644 --- a/torchrec/distributed/planner/tests/test_planners.py +++ b/torchrec/distributed/planner/tests/test_planners.py @@ -8,7 +8,7 @@ # pyre-strict import unittest -from typing import cast, List, Optional +from typing import cast, Dict, List, Optional import torch from torch import nn @@ -18,7 +18,7 @@ from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.planner.enumerators import EmbeddingEnumerator from torchrec.distributed.planner.perf_models import NoopPerfModel -from torchrec.distributed.planner.planners import EmbeddingShardingPlanner +from torchrec.distributed.planner.planners import EmbeddingShardingPlanner, extract_plan from torchrec.distributed.planner.proposers import EmbeddingOffloadScaleupProposer from torchrec.distributed.planner.stats import EmbeddingStats from torchrec.distributed.planner.storage_reservations import ( @@ -26,8 +26,10 @@ ) from torchrec.distributed.planner.types import ( ParameterConstraints, + PlanLoader, PlannerError, PlannerErrorType, + Shard, ShardingOption, Topology, ) @@ -44,6 +46,7 @@ ShardingPlan, ShardingType, ) +from torchrec.distributed.utils import none_throws from torchrec.modules.embedding_configs import EmbeddingBagConfig @@ -842,3 +845,420 @@ def test_planner_with_virtual_table(self) -> None: self.assertTrue( any("Min HBM: 0.016 GB on ranks [0, 1]" in line for line in stats) ) + + +class MockPlanLoader(PlanLoader): + """Mock PlanLoader implementation for testing.""" + + def __init__( + self, + loaded_sharding_options: Optional[Dict[int, ShardingOption]] = None, + context_hash: Optional[str] = None, + plan_id: str = "test_plan_123", + ) -> None: + self._loaded_sharding_options = loaded_sharding_options + self._context_hash = context_hash + self._plan_id = plan_id + + def load(self) -> Optional[Dict[int, ShardingOption]]: + return self._loaded_sharding_options + + def plan_context_hash(self) -> Optional[str]: + return self._context_hash + + def get_plan_id(self) -> str: + return self._plan_id + + +class TestPlanLoaderIntegration(unittest.TestCase): + def setUp(self) -> None: + compute_device = "cuda" + self.topology = Topology( + world_size=2, hbm_cap=1024 * 1024 * 2, compute_device=compute_device + ) + self.tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=64, + name="table_" + str(i), + feature_names=["feature_" + str(i)], + ) + for i in range(2) # Reduced to 2 tables for simplicity + ] + self.constraints = { + "table_0": ParameterConstraints( + enforce_hbm=True, + cache_params=CacheParams( + algorithm=CacheAlgorithm.LFU, + ), + feature_names=self.tables[0].feature_names, + ), + "table_1": ParameterConstraints( + enforce_hbm=False, + stochastic_rounding=True, + feature_names=self.tables[1].feature_names, + ), + } + self.model = TestSparseNN( + tables=self.tables, sparse_device=torch.device("meta") + ) + + def test_plan_loader_with_valid_plan(self) -> None: + """Test EmbeddingShardingPlanner with PlanLoader that provides a valid plan.""" + # First, create a planner without loader to generate a baseline plan + baseline_planner = EmbeddingShardingPlanner( + topology=self.topology, constraints=self.constraints + ) + baseline_plan = baseline_planner.plan( + module=self.model, sharders=get_default_sharders() + ) + + # Extract the best plan from 
baseline planner + best_plan = baseline_planner._best_plan + self.assertIsNotNone(best_plan) + + # Create loaded sharding options map from the best plan + loaded_sharding_options = {} + for so in best_plan: + # Modify the shards to simulate a loaded plan with different shard assignments + modified_shards = [ + Shard( + size=shard.size, + offset=shard.offset, + storage=shard.storage, + perf=shard.perf, + rank=( + 1 - shard.rank if shard.rank is not None else None + ), # Flip ranks + ) + for shard in so.shards + ] + loaded_so = ShardingOption( + name=so.name, + tensor=so.tensor, + module=so.module, + input_lengths=so.input_lengths, + batch_size=so.batch_size, + compute_kernel=so.compute_kernel, + sharding_type=so.sharding_type, + partition_by=so.partition_by, + shards=modified_shards, + cache_params=so.cache_params, + enforce_hbm=so.enforce_hbm, + stochastic_rounding=so.stochastic_rounding, + bounds_check_mode=so.bounds_check_mode, + feature_names=so.feature_names, + ) + loaded_sharding_options[so.storage_hash()] = loaded_so + + # Create mock plan loader with matching context hash + context_hash = baseline_planner.hash_planner_context_inputs_str() + mock_loader = MockPlanLoader( + loaded_sharding_options=loaded_sharding_options, + context_hash=context_hash, + ) + + # Create planner with plan loader + planner_with_loader = EmbeddingShardingPlanner( + topology=self.topology, + constraints=self.constraints, + plan_loader=mock_loader, + ) + + # Plan with loader should use the loaded plan + loaded_plan = planner_with_loader.plan( + module=self.model, sharders=get_default_sharders() + ) + + # Verify the plan was loaded (should have flipped rank assignments) + self.assertIsNotNone(loaded_plan) + self.assertEqual(len(loaded_plan.plan), len(baseline_plan.plan)) + + # Check that ranks were actually flipped in the loaded plan + for module_name, module_plan in loaded_plan.plan.items(): + baseline_module_plan = baseline_plan.plan[module_name] + for param_name, param_sharding in cast( + EmbeddingModuleShardingPlan, module_plan + ).items(): + baseline_param_sharding = cast( + EmbeddingModuleShardingPlan, baseline_module_plan + )[param_name] + # The ranks should be different (flipped) from baseline + self.assertNotEqual(param_sharding.ranks, baseline_param_sharding.ranks) + + def test_plan_loader_with_context_mismatch(self) -> None: + """Test EmbeddingShardingPlanner with PlanLoader that has mismatched context hash.""" + # Create mock plan loader with different context hash + mock_loader = MockPlanLoader( + loaded_sharding_options={}, + context_hash="mismatched_hash", + ) + + # Create planner with plan loader + planner_with_loader = EmbeddingShardingPlanner( + topology=self.topology, + constraints=self.constraints, + plan_loader=mock_loader, + ) + + # Planning should raise PlannerError due to context mismatch + with self.assertRaises(PlannerError) as context: + planner_with_loader.plan(module=self.model, sharders=get_default_sharders()) + + self.assertEqual( + context.exception.error_type, + PlannerErrorType.PLANNER_INPUT_CONTEXT_MISMATCH, + ) + self.assertIn("planner input mismatch", str(context.exception)) + + def test_plan_loader_with_no_loaded_options(self) -> None: + """Test EmbeddingShardingPlanner with PlanLoader that returns no loaded options.""" + # First get the correct context hash + baseline_planner = EmbeddingShardingPlanner( + topology=self.topology, constraints=self.constraints + ) + baseline_planner.plan(module=self.model, sharders=get_default_sharders()) + context_hash = 
+        context_hash = baseline_planner.hash_planner_context_inputs_str()
+
+        # Create a mock plan loader with no loaded options but a matching context
+        mock_loader = MockPlanLoader(
+            loaded_sharding_options=None,
+            context_hash=context_hash,
+        )
+
+        # Create a planner with the plan loader
+        planner_with_loader = EmbeddingShardingPlanner(
+            topology=self.topology,
+            constraints=self.constraints,
+            plan_loader=mock_loader,
+        )
+
+        # Planning should succeed and generate a new plan (no loading)
+        loaded_plan = planner_with_loader.plan(
+            module=self.model, sharders=get_default_sharders()
+        )
+
+        # Verify a plan was generated
+        self.assertIsNotNone(loaded_plan)
+        self.assertTrue(len(loaded_plan.plan) > 0)
+
+
+class TestExtractPlan(unittest.TestCase):
+    def setUp(self) -> None:
+        compute_device = "cuda"
+        self.topology = Topology(
+            world_size=2, hbm_cap=1024 * 1024 * 2, compute_device=compute_device
+        )
+        self.tables = [
+            EmbeddingBagConfig(
+                num_embeddings=100,
+                embedding_dim=64,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(4)
+        ]
+        self.constraints = {
+            "table_0": ParameterConstraints(
+                enforce_hbm=True,
+                cache_params=CacheParams(
+                    algorithm=CacheAlgorithm.LFU,
+                ),
+                feature_names=self.tables[0].feature_names,
+            ),
+            "table_1": ParameterConstraints(
+                enforce_hbm=False,
+                stochastic_rounding=True,
+                feature_names=self.tables[1].feature_names,
+            ),
+            "table_2": ParameterConstraints(
+                bounds_check_mode=BoundsCheckMode.FATAL,
+                feature_names=self.tables[2].feature_names,
+            ),
+            "table_3": ParameterConstraints(
+                cache_params=CacheParams(
+                    algorithm=CacheAlgorithm.LFU,
+                    load_factor=0.1,
+                    reserved_memory=1.0,
+                    precision=DataType.FP16,
+                ),
+                feature_names=self.tables[3].feature_names,
+            ),
+        }
+        self.planner = EmbeddingShardingPlanner(
+            topology=self.topology, constraints=self.constraints
+        )
+        self.model = TestSparseNN(
+            tables=self.tables, sparse_device=torch.device("meta")
+        )
+        self.sharding_plan = self.planner.plan(
+            module=self.model, sharders=get_default_sharders()
+        )
+
+    def _create_loaded_sharding_options_map(
+        self, best_plan: List[ShardingOption]
+    ) -> Dict[int, ShardingOption]:
+        """Creates a loaded sharding options map from enumerated sharding options."""
+        loaded_map = {}
+        for so in best_plan:
+            sharding_options = ShardingOption(
+                name=so.name,
+                tensor=so.tensor,
+                module=so.module,
+                input_lengths=so.input_lengths,
+                sharding_type=so.sharding_type,
+                batch_size=so.batch_size,
+                partition_by=so.partition_by,
+                compute_kernel=so.compute_kernel,
+                shards=so.shards,
+                is_pooled=so.is_pooled,
+                feature_names=so.feature_names,
+                cache_params=so.cache_params,
+            )
+
+            loaded_map[so.storage_hash()] = sharding_options
+
+        return loaded_map
+
+    def test_extract_plan_success(self) -> None:
+        """Test successful extraction of a plan."""
+        enumerated_plan = (
+            # pyre-ignore
+            self.planner._enumerator.last_stored_search_space
+        )
+        best_plan = none_throws(self.planner._best_plan)
+        loaded_sharding_options = self._create_loaded_sharding_options_map(best_plan)
+
+        result = extract_plan(enumerated_plan, loaded_sharding_options)
+
+        self.assertEqual(len(result), len(best_plan))
+
+        for i, result_so in enumerate(result):
+            expected_so = best_plan[i]
+            self.assertEqual(result_so.name, expected_so.name)
+            self.assertEqual(result_so.tensor.shape, expected_so.tensor.shape)
+            self.assertEqual(result_so.tensor.dtype, expected_so.tensor.dtype)
+            self.assertEqual(result_so.tensor.device, expected_so.tensor.device)
+            self.assertEqual(result_so.module, expected_so.module)
+            self.assertEqual(result_so.input_lengths, expected_so.input_lengths)
+            self.assertEqual(result_so.batch_size, expected_so.batch_size)
+            self.assertEqual(result_so.compute_kernel, expected_so.compute_kernel)
+            self.assertEqual(result_so.sharding_type, expected_so.sharding_type)
+            self.assertEqual(result_so.partition_by, expected_so.partition_by)
+            self.assertEqual(result_so.shards, expected_so.shards)
+            self.assertEqual(result_so.is_pooled, expected_so.is_pooled)
+            self.assertEqual(result_so.feature_names, expected_so.feature_names)
+            self.assertEqual(result_so.cache_params, expected_so.cache_params)
+
+    def test_extract_plan_duplicate_storage_hash_error(self) -> None:
+        """Test extract_plan failure when duplicate storage hashes exist."""
+        enumerated_plan = (
+            # pyre-ignore
+            self.planner._enumerator.last_stored_search_space
+        )
+        best_plan = none_throws(self.planner._best_plan)
+        loaded_sharding_options = self._create_loaded_sharding_options_map(best_plan)
+
+        # Create a search space with duplicate storage hashes by duplicating the first option
+        duplicate_search_space = [
+            enumerated_plan[0],
+            enumerated_plan[0],
+        ]  # Same option twice
+
+        with self.assertRaises(PlannerError) as context:
+            extract_plan(duplicate_search_space, loaded_sharding_options)
+
+        self.assertEqual(
+            context.exception.error_type, PlannerErrorType.PLAN_LOADING_FAILED
+        )
+        self.assertIn("Found a duplicate storage hash", str(context.exception))
+
+    def test_extract_plan_empty_search_space(self) -> None:
+        """Test extract_plan with an empty search space."""
+        result = extract_plan([], {})
+        self.assertEqual(result, [])
+
+    def test_extract_plan_empty_loaded_options(self) -> None:
+        """Test extract_plan with empty loaded options but a non-empty search space."""
+        enumerated_plan = (
+            # pyre-ignore
+            self.planner._enumerator.last_stored_search_space
+        )
+
+        # When the loaded options are empty, extract_plan should return an empty list.
+        # This is the correct behavior: no matching options means no extracted options.
+        result = extract_plan(enumerated_plan, {})
+        self.assertEqual(result, [])
+
+    def test_extract_plan_excess_loaded_options(self) -> None:
+        """Test extract_plan when loaded options contain more entries than the search space."""
+        enumerated_plan = (
+            # pyre-ignore
+            self.planner._enumerator.last_stored_search_space
+        )
+        best_plan = none_throws(self.planner._best_plan)
+        loaded_sharding_options = self._create_loaded_sharding_options_map(best_plan)
+
+        extra_so = ShardingOption(
+            name="extra_table",
+            tensor=torch.tensor([1, 2, 3], device=torch.device("meta")),
+            module=("extra_table.test", torch.nn.Module()),
+            input_lengths=[100],
+            batch_size=128,
+            compute_kernel="fused",
+            sharding_type=ShardingType.TABLE_WISE.value,
+            partition_by="uniform",
+            shards=[Shard(size=[100, 64], offset=[0, 0])],
+            feature_names=["extra_feature"],
+        )
+        loaded_sharding_options[99999] = extra_so  # Arbitrary hash that won't match
+
+        with self.assertRaises(PlannerError) as context:
+            extract_plan(enumerated_plan, loaded_sharding_options)
+
+        self.assertEqual(
+            context.exception.error_type, PlannerErrorType.PLAN_LOADING_FAILED
+        )
+        self.assertIn("not all search space is covered", str(context.exception))
+
+    def test_extract_plan_properties_preservation(self) -> None:
+        """Test that extract_plan preserves all non-shard properties from the search space."""
+        enumerated_plan = (
+            # pyre-ignore
+            self.planner._enumerator.last_stored_search_space
+        )
+        best_plan = none_throws(self.planner._best_plan)
+        loaded_sharding_options = self._create_loaded_sharding_options_map(best_plan)
+
+        # Modify loaded options to have different shards but keep other properties
+        for loaded_so in loaded_sharding_options.values():
+            # Change the shard data to verify only shards are updated
+            loaded_so.shards = [
+                Shard(size=[200, 128], offset=[0, 0], rank=0)  # Different shard
+            ]
+
+        result = extract_plan(enumerated_plan, loaded_sharding_options)
+
+        # Verify that result has search space properties but loaded shards
+        for result_so in result:
+            # Find the matching search space option by storage hash
+            search_so = next(
+                so
+                for so in enumerated_plan
+                if so.storage_hash() == result_so.storage_hash()
+            )
+            loaded_so = loaded_sharding_options[result_so.storage_hash()]
+
+            # Properties from search space should be preserved
+            self.assertEqual(result_so.name, search_so.name)
+            self.assertEqual(result_so.compute_kernel, search_so.compute_kernel)
+            self.assertEqual(result_so.sharding_type, search_so.sharding_type)
+            self.assertEqual(result_so.batch_size, search_so.batch_size)
+            self.assertEqual(result_so.feature_names, search_so.feature_names)
+
+            # Shards should come from loaded options
+            self.assertEqual(result_so.shards, loaded_so.shards)
+            self.assertEqual(len(result_so.shards), 1)
+            self.assertEqual(result_so.shards[0].size, [200, 128])
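
Example usage (illustrative sketch, not part of the patch): the snippet below shows how the new plan_loader hook might be wired up end to end. InMemoryPlanLoader and capture_plan_state are hypothetical names invented for this example; the PlanLoader interface (load, plan_context_hash, get_plan_id), the keying by ShardingOption.storage_hash(), and the hash_planner_context_inputs_str() validation are taken from the diff above.

from typing import Dict, Optional, Tuple

from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
from torchrec.distributed.planner.types import PlanLoader, ShardingOption


class InMemoryPlanLoader(PlanLoader):
    """Hypothetical loader that replays a previously captured plan from memory."""

    def __init__(
        self,
        options: Dict[int, ShardingOption],  # keyed by ShardingOption.storage_hash()
        context_hash: str,  # hash_planner_context_inputs_str() captured at save time
        plan_id: str = "in_memory_plan",
    ) -> None:
        self._options = options
        self._context_hash = context_hash
        self._plan_id = plan_id

    def load(self) -> Optional[Dict[int, ShardingOption]]:
        # Returning None makes the planner fall back to normal plan generation.
        return self._options or None

    def plan_context_hash(self) -> Optional[str]:
        # Compared against the current planner's hash_planner_context_inputs_str();
        # a mismatch raises PlannerError(PLANNER_INPUT_CONTEXT_MISMATCH).
        return self._context_hash

    def get_plan_id(self) -> str:
        return self._plan_id


def capture_plan_state(
    planner: EmbeddingShardingPlanner,
) -> Tuple[Dict[int, ShardingOption], str]:
    """Capture replayable state from a planner whose plan() has already run.

    Accesses the private _best_plan attribute, mirroring the integration tests above.
    """
    best_plan = planner._best_plan or []
    options = {so.storage_hash(): so for so in best_plan}
    return options, planner.hash_planner_context_inputs_str()


# Typical flow (topology/model/sharders assumed to exist):
#   options, ctx_hash = capture_plan_state(first_planner)  # after first_planner.plan(...)
#   replay_planner = EmbeddingShardingPlanner(
#       topology=topology,
#       plan_loader=InMemoryPlanLoader(options, ctx_hash),
#   )
#   sharding_plan = replay_planner.plan(module=model, sharders=get_default_sharders())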