diff --git a/pr_split/planner/partitioning.py b/pr_split/planner/partitioning.py
index 840109f..71d95e7 100644
--- a/pr_split/planner/partitioning.py
+++ b/pr_split/planner/partitioning.py
@@ -176,6 +176,147 @@ def _group_units_graph(
     return grouped_units
 
 
+def _group_load(group_units: list[PartitionUnit]) -> int:
+    return sum(unit.loc for unit in group_units)
+
+
+def _group_anchor_position(group_units: list[PartitionUnit]) -> int:
+    return min(unit.position for unit in group_units)
+
+
+def _group_affinity(
+    left_group: list[PartitionUnit], right_group: list[PartitionUnit], priority: Priority
+) -> int:
+    return sum(
+        _affinity_score(left_unit, right_unit, priority)
+        for left_unit in left_group
+        for right_unit in right_group
+    )
+
+
+def _shared_file_merge_is_contiguous(
+    grouped_units: list[list[PartitionUnit]], source_idx: int, target_idx: int
+) -> bool:
+    source_files = {unit.file_path for unit in grouped_units[source_idx]}
+    target_files = {unit.file_path for unit in grouped_units[target_idx]}
+    shared_files = source_files & target_files
+
+    if not shared_files:
+        return True
+
+    for file_path in shared_files:
+        ordered_occurrences: list[tuple[int, int]] = []
+        for group_idx, group_units in enumerate(grouped_units):
+            positions = [unit.position for unit in group_units if unit.file_path == file_path]
+            if positions:
+                ordered_occurrences.append((min(positions), group_idx))
+
+        occurrence_order = [group_idx for _, group_idx in sorted(ordered_occurrences)]
+        if abs(occurrence_order.index(source_idx) - occurrence_order.index(target_idx)) != 1:
+            return False
+
+    return True
+
+
+def _best_graph_merge_target(
+    grouped_units: list[list[PartitionUnit]], source_idx: int, *, settings: Settings
+) -> int | None:
+    if settings.min_loc is None:
+        return None
+
+    source_group = grouped_units[source_idx]
+    source_load = _group_load(source_group)
+    source_underflow = max(0, settings.min_loc - source_load)
+    if source_underflow == 0:
+        return None
+
+    best_idx: int | None = None
+    best_key: tuple[int, int, int, int, int, int] | None = None
+    source_anchor = _group_anchor_position(source_group)
+
+    for target_idx, target_group in enumerate(grouped_units):
+        if target_idx == source_idx:
+            continue
+
+        merged_load = source_load + _group_load(target_group)
+        if merged_load > settings.max_loc:
+            continue
+        if not _shared_file_merge_is_contiguous(grouped_units, source_idx, target_idx):
+            continue
+
+        current_underflow = source_underflow + max(0, settings.min_loc - _group_load(target_group))
+        merged_underflow = max(0, settings.min_loc - merged_load)
+        if merged_underflow >= current_underflow:
+            continue
+
+        target_anchor = _group_anchor_position(target_group)
+        merge_key = (
+            int(merged_load >= settings.min_loc),
+            _group_affinity(source_group, target_group, settings.priority),
+            -abs(source_anchor - target_anchor),
+            -abs(settings.max_loc - merged_load),
+            -min(source_idx, target_idx),
+            -max(source_idx, target_idx),
+        )
+        if best_key is None or merge_key > best_key:
+            best_idx = target_idx
+            best_key = merge_key
+
+    return best_idx
+
+
+def _merge_group_units(
+    grouped_units: list[list[PartitionUnit]], source_idx: int, target_idx: int
+) -> list[list[PartitionUnit]]:
+    merged_group = sorted(
+        grouped_units[target_idx] + grouped_units[source_idx],
+        key=lambda unit: unit.position,
+    )
+    repaired_groups = [list(group_units) for group_units in grouped_units]
+    repaired_groups[target_idx] = merged_group
+    del repaired_groups[source_idx]
+    return repaired_groups
+
+
+def _repair_graph_min_loc(
+    grouped_units: list[list[PartitionUnit]], *, settings: Settings
+) -> list[list[PartitionUnit]]:
+    if settings.min_loc is None or len(grouped_units) < 2:
+        return grouped_units
+
+    repaired_groups = [list(group_units) for group_units in grouped_units]
+
+    while True:
+        undersized_group_indices = sorted(
+            (
+                group_idx
+                for group_idx, group_units in enumerate(repaired_groups)
+                if _group_load(group_units) < settings.min_loc
+            ),
+            key=lambda group_idx: (
+                _group_load(repaired_groups[group_idx]),
+                _group_anchor_position(repaired_groups[group_idx]),
+                group_idx,
+            ),
+        )
+
+        merged = False
+        for source_idx in undersized_group_indices:
+            target_idx = _best_graph_merge_target(
+                repaired_groups,
+                source_idx,
+                settings=settings,
+            )
+            if target_idx is None:
+                continue
+            repaired_groups = _merge_group_units(repaired_groups, source_idx, target_idx)
+            merged = True
+            break
+
+        if not merged:
+            return repaired_groups
+
+
 def _group_units_cp_sat(
     units: list[PartitionUnit], *, settings: Settings
 ) -> list[list[PartitionUnit]]:
@@ -392,6 +533,7 @@ def partition_diff(parsed_diff: ParsedDiff, settings: Settings) -> list[Group]:
     match settings.partition_strategy:
         case PartitionStrategy.GRAPH:
             grouped_units = _group_units_graph(units, settings=settings)
+            grouped_units = _repair_graph_min_loc(grouped_units, settings=settings)
         case PartitionStrategy.CP_SAT:
             grouped_units = _group_units_cp_sat(units, settings=settings)
         case _:
diff --git a/tests/test_partitioning_extensive.py b/tests/test_partitioning_extensive.py
index bdd9e55..34e5090 100644
--- a/tests/test_partitioning_extensive.py
+++ b/tests/test_partitioning_extensive.py
@@ -79,6 +79,7 @@ def _settings(
     monkeypatch: pytest.MonkeyPatch,
     *,
     max_loc: int,
+    min_loc: int | None = None,
     partition_strategy: PartitionStrategy,
     priority: Priority = Priority.ORTHOGONAL,
 ) -> Settings:
@@ -86,6 +87,7 @@ def _settings(
     monkeypatch.delenv("OPENAI_API_KEY", raising=False)
     return Settings(
         max_loc=max_loc,
+        min_loc=min_loc,
         partition_strategy=partition_strategy,
         priority=priority,
     )
@@ -104,13 +106,32 @@ def _group_signature(groups: list[Group]) -> tuple[tuple[tuple[str, tuple[int, .
     return tuple(normalized)
 
 
-def _assert_valid_plan(groups: list[Group], diff_text: str, max_loc: int) -> None:
+def _assert_valid_plan(
+    groups: list[Group],
+    diff_text: str,
+    max_loc: int,
+    min_loc: int | None = None,
+) -> None:
     parsed = parse_diff(diff_text)
-    warnings = validate_plan(groups, parsed, PlanDAG(groups), max_loc)
+    warnings = validate_plan(groups, parsed, PlanDAG(groups), max_loc, min_loc=min_loc)
     assert warnings == []
 
 
 class TestPartitionDiffGraphExtensive:
+    def test_min_loc_merges_undersized_groups_when_possible(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        settings = _settings(
+            monkeypatch,
+            max_loc=10,
+            min_loc=5,
+            partition_strategy=PartitionStrategy.GRAPH,
+            priority=Priority.ORTHOGONAL,
+        )
+        groups = partition_diff(parse_diff(UNRELATED_DIFF), settings)
+        assert len(groups) == 1
+        _assert_valid_plan(groups, UNRELATED_DIFF, 10, min_loc=5)
+
     def test_orthogonal_keeps_unrelated_files_separate(
         self, monkeypatch: pytest.MonkeyPatch
     ) -> None:
@@ -163,6 +184,18 @@ def test_is_deterministic_across_runs(self, monkeypatch: pytest.MonkeyPatch) -> None:
         signatures = {_group_signature(partition_diff(parsed, settings)) for _ in range(3)}
         assert len(signatures) == 1
 
+    def test_min_loc_repair_is_deterministic(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        settings = _settings(
+            monkeypatch,
+            max_loc=10,
+            min_loc=5,
+            partition_strategy=PartitionStrategy.GRAPH,
+            priority=Priority.ORTHOGONAL,
+        )
+        parsed = parse_diff(UNRELATED_DIFF)
+        signatures = {_group_signature(partition_diff(parsed, settings)) for _ in range(3)}
+        assert len(signatures) == 1
+
 
 class TestPartitionDiffCpSatExtensive:
     def test_missing_ortools_raises_clean_error(self, monkeypatch: pytest.MonkeyPatch) -> None: