Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ pr-split split feature-branch --base main --dry-run
| Flag | Default | Description |
|------|---------|-------------|
| `--base` | `main` | Base branch for the diff |
| `--max-loc` | `400` | Soft limit on diff lines per sub-PR |
| `--min-loc` | unset | Optional minimum target diff lines per sub-PR |
| `--max-loc` | `400` | Maximum target diff lines per sub-PR |
| `--strict-loc-bounds` | `false` | Exit instead of proceeding when the final plan violates configured LOC bounds |
| `--priority` | `orthogonal` | Grouping priority (`orthogonal` or `logical`) |
| `--chunk-strategy` | `dynamic_programming` | Large-diff chunking strategy (`dynamic_programming` or `greedy`) |
| `--partition-strategy` | `llm` | Hunk-to-PR partition backend (`llm`, `graph`, or `cp_sat`) |
Expand Down Expand Up @@ -167,7 +169,9 @@ Settings can be set via environment variables with the `PR_SPLIT_` prefix:
| `ANTHROPIC_API_KEY` | (required for Anthropic) | Anthropic API key |
| `OPENAI_API_KEY` | (required for OpenAI) | OpenAI API key |
| `PR_SPLIT_MODEL` | auto per provider | Model name (defaults to best available model for the chosen provider) |
| `PR_SPLIT_MAX_LOC` | `400` | Default soft limit on diff lines |
| `PR_SPLIT_MIN_LOC` | unset | Optional minimum target diff lines |
| `PR_SPLIT_MAX_LOC` | `400` | Default maximum target diff lines |
| `PR_SPLIT_STRICT_LOC_BOUNDS` | `false` | Fail if the final plan violates configured LOC bounds |
| `PR_SPLIT_PRIORITY` | `orthogonal` | Default grouping priority |
| `PR_SPLIT_CHUNK_STRATEGY` | `dynamic_programming` | Large-diff chunking strategy |
| `PR_SPLIT_PARTITION_STRATEGY` | `llm` | Hunk-to-PR partition backend |
Expand Down
77 changes: 62 additions & 15 deletions pr_split/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import typer
from loguru import logger
from pydantic import ValidationError
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
Expand All @@ -24,7 +25,9 @@
DEFAULT_CHUNK_STRATEGY,
DEFAULT_CP_SAT_TIMEOUT_SECONDS,
DEFAULT_MAX_LOC,
DEFAULT_MIN_LOC,
DEFAULT_PARTITION_STRATEGY,
DEFAULT_STRICT_LOC_BOUNDS,
PLAN_DIR,
PLAN_FILE,
AssignmentType,
Expand Down Expand Up @@ -129,6 +132,17 @@ def _validate_inputs(dev_branch: str, base: str, *, dry_run: bool = False) -> No
raise typer.Exit(1)


def _handle_loc_bound_warnings(warnings: list[str], *, strict_loc_bounds: bool) -> None:
if strict_loc_bounds and warnings:
console.print(f"[red]{ErrorMsg.LOC_BOUNDS_STRICT_FAILED()}[/red]")
for warning in warnings:
console.print(f"[red]- {warning}[/red]")
raise typer.Exit(1)

for warning in warnings:
logger.warning(warning)


def _present_plan(groups: list[Group]) -> None:
table = Table(title="Split Plan")
table.add_column("ID")
Expand Down Expand Up @@ -543,9 +557,30 @@ def _resolve_fork_ref(dev_branch: str) -> ForkPRInfo | None:
def split(
dev_branch: Annotated[str, typer.Argument(help="Branch name, PR number, or user:branch")],
base: Annotated[str, typer.Option(help="Base branch")] = "main",
min_loc: Annotated[
int | None,
typer.Option(
"--min-loc",
envvar="PR_SPLIT_MIN_LOC",
help="Minimum target diff lines per sub-PR",
),
] = DEFAULT_MIN_LOC,
max_loc: Annotated[
int, typer.Option(help="Soft limit on diff lines per sub-PR")
int,
typer.Option(
"--max-loc",
envvar="PR_SPLIT_MAX_LOC",
help="Maximum target diff lines per sub-PR",
),
] = DEFAULT_MAX_LOC,
strict_loc_bounds: Annotated[
bool,
typer.Option(
"--strict-loc-bounds",
envvar="PR_SPLIT_STRICT_LOC_BOUNDS",
help="Fail if the final plan violates configured LOC bounds",
),
] = DEFAULT_STRICT_LOC_BOUNDS,
priority: Annotated[Priority, typer.Option(help="Grouping priority")] = Priority.ORTHOGONAL,
chunk_strategy: Annotated[
ChunkStrategy, typer.Option(help="Chunking strategy for large diffs")
Expand Down Expand Up @@ -614,20 +649,25 @@ def split(
)
)

settings = Settings(
max_loc=max_loc,
cp_sat_timeout=cp_sat_timeout,
priority=priority,
chunk_strategy=chunk_strategy,
partition_strategy=partition_strategy,
)
try:
settings = Settings(
min_loc=min_loc,
max_loc=max_loc,
strict_loc_bounds=strict_loc_bounds,
cp_sat_timeout=cp_sat_timeout,
priority=priority,
chunk_strategy=chunk_strategy,
partition_strategy=partition_strategy,
)
except (ValidationError, ValueError) as exc:
console.print(f"[red]{exc}[/red]")
raise typer.Exit(1) from exc
groups = plan_split(parsed_diff, settings)

logger.info(logs.VALIDATING_PLAN)
dag = PlanDAG(groups)
warnings = validate_plan(groups, parsed_diff, dag, max_loc)
for warning in warnings:
logger.warning(warning)
warnings = validate_plan(groups, parsed_diff, dag, settings.max_loc, min_loc=settings.min_loc)
_handle_loc_bound_warnings(warnings, strict_loc_bounds=settings.strict_loc_bounds)
logger.success(logs.VALIDATION_PASSED)

logger.info(logs.PRESENTING_PLAN)
Expand All @@ -645,9 +685,14 @@ def split(
raise typer.Exit(1)
try:
dag = PlanDAG(groups)
warnings = validate_plan(groups, parsed_diff, dag, max_loc)
for warning in warnings:
logger.warning(warning)
warnings = validate_plan(
groups,
parsed_diff,
dag,
settings.max_loc,
min_loc=settings.min_loc,
)
_handle_loc_bound_warnings(warnings, strict_loc_bounds=settings.strict_loc_bounds)
logger.success("Edited plan validation passed")
except PRSplitError as exc:
console.print(f"[red]Edited plan is invalid: {exc}[/red]")
Expand All @@ -658,7 +703,9 @@ def split(
split_plan = SplitPlan(
dev_branch=dev_branch,
base_branch=base,
max_loc=max_loc,
min_loc=settings.min_loc,
max_loc=settings.max_loc,
strict_loc_bounds=settings.strict_loc_bounds,
priority=priority,
groups=groups,
author=author,
Expand Down
15 changes: 14 additions & 1 deletion pr_split/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
DEFAULT_CHUNK_STRATEGY,
DEFAULT_CP_SAT_TIMEOUT_SECONDS,
DEFAULT_MAX_LOC,
DEFAULT_MIN_LOC,
DEFAULT_MODEL,
DEFAULT_PARTITION_STRATEGY,
DEFAULT_STRICT_LOC_BOUNDS,
OPENAI_MAX_CONTEXT_TOKENS,
OPENAI_MODEL,
ChunkStrategy,
PartitionStrategy,
Priority,
Provider,
)
from .exceptions import ErrorMsg


class Settings(BaseSettings):
Expand All @@ -30,7 +33,9 @@ class Settings(BaseSettings):
)
provider: Provider = Provider.ANTHROPIC
model: str = ""
max_loc: int = DEFAULT_MAX_LOC
min_loc: int | None = Field(default=DEFAULT_MIN_LOC, ge=1)
max_loc: int = Field(default=DEFAULT_MAX_LOC, gt=0)
strict_loc_bounds: bool = DEFAULT_STRICT_LOC_BOUNDS
cp_sat_timeout: float = DEFAULT_CP_SAT_TIMEOUT_SECONDS
priority: Priority = Priority.ORTHOGONAL
chunk_strategy: ChunkStrategy = DEFAULT_CHUNK_STRATEGY
Expand All @@ -48,6 +53,14 @@ def set_default_model(self):
raise NotImplementedError(f"No default model for provider '{self.provider}'")
return self

@model_validator(mode="after")
def validate_loc_bounds(self):
    """Reject settings whose minimum LOC bound is not below the maximum.

    The check is skipped when ``min_loc`` is ``None``.

    Returns:
        This settings instance, unchanged.

    Raises:
        ValueError: If ``min_loc`` is greater than or equal to ``max_loc``.
    """
    if self.min_loc is None or self.min_loc < self.max_loc:
        return self
    # Equal bounds are rejected too: min_loc must be strictly less than max_loc.
    raise ValueError(
        ErrorMsg.MIN_LOC_GE_MAX_LOC(min_loc=self.min_loc, max_loc=self.max_loc)
    )

@model_validator(mode="after")
def check_api_key_is_present(self):
if self.partition_strategy != PartitionStrategy.LLM:
Expand Down
7 changes: 7 additions & 0 deletions pr_split/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ class PartitionStrategy(StrEnum):
CP_SAT = "cp_sat"


class LocViolationType(StrEnum):
BELOW_MIN = "below_min"
ABOVE_MAX = "above_max"


class PRState(StrEnum):
OPEN = "open"
CLOSED = "closed"
Expand All @@ -36,7 +41,9 @@ class Provider(StrEnum):
BRANCH_PREFIX = "pr-split/"
PLAN_DIR = ".pr-split"
PLAN_FILE = ".pr-split/plan.json"
DEFAULT_MIN_LOC: int | None = None
DEFAULT_MAX_LOC = 400
DEFAULT_STRICT_LOC_BOUNDS = False
DEFAULT_CP_SAT_TIMEOUT_SECONDS = 15.0
DEFAULT_CHUNK_STRATEGY = ChunkStrategy.DYNAMIC_PROGRAMMING
DEFAULT_PARTITION_STRATEGY = PartitionStrategy.LLM
Expand Down
2 changes: 2 additions & 0 deletions pr_split/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class ErrorMsg(StrEnum):
PR_FETCH_FAILED = "Failed to fetch fork branch for PR #{number}: {detail}"
FORK_FETCH_FAILED = "Failed to fetch {user}:{branch}: {detail}"
HUNK_TOO_LARGE = "Hunk {file}[{index}] has ~{tokens} estimated tokens, exceeds budget {budget}"
MIN_LOC_GE_MAX_LOC = "min_loc {min_loc} must be less than max_loc {max_loc}"
LOC_BOUNDS_STRICT_FAILED = "Plan violates configured LOC bounds"

def __call__(self, **kwargs: object) -> str:
return self.value.format(**kwargs) if kwargs else self.value
Expand Down
11 changes: 8 additions & 3 deletions pr_split/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
LLM_RESPONSE_RECEIVED = "Received split plan with {count} groups"
VALIDATING_PLAN = "Validating split plan"
VALIDATION_PASSED = "Plan validation passed"
LOC_SOFT_WARN = "Group '{group}' has +{added}/-{removed} diff lines (soft limit: {limit})"
LOC_MIN_WARN = (
"Group '{group}' has {loc} diff lines (+{added}/-{removed}) below minimum: {limit}"
)
LOC_MAX_WARN = (
"Group '{group}' has {loc} diff lines (+{added}/-{removed}) above maximum: {limit}"
)
PRESENTING_PLAN = "Split plan ready for review"
CREATING_BRANCH = "Creating branch {branch} from {base}"
CREATING_MERGE_BASE = "Creating merge base {branch} from parents: {parents}"
Expand Down Expand Up @@ -46,6 +51,6 @@
HUNK_AUTO_ASSIGNED = "Auto-assigned uncovered hunk {file}[{index}] to group '{group}'"
UNCOVERED_HUNKS_FIXED = "Auto-assigned {count} uncovered hunk(s) to existing groups"
PLAN_METRICS = (
"Plan metrics: groups={groups}, max_group_loc={max_loc}, overflow={overflow}, "
"width={width}, depth={depth}, scatter={scatter}, objective={objective}"
"Plan metrics: groups={groups}, max_group_loc={max_loc}, underflow={underflow}, "
"overflow={overflow}, width={width}, depth={depth}, scatter={scatter}, objective={objective}"
)
3 changes: 2 additions & 1 deletion pr_split/planner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,12 @@ def plan_split(
f"Unsupported partition strategy '{settings.partition_strategy}'"
)

metrics = score_plan(groups, settings.max_loc)
metrics = score_plan(groups, settings.max_loc, settings.min_loc)
logger.info(
logs.PLAN_METRICS.format(
groups=metrics.total_groups,
max_loc=metrics.max_group_loc,
underflow=metrics.loc_underflow,
overflow=metrics.loc_overflow,
width=metrics.dag_width,
depth=metrics.dag_depth,
Expand Down
21 changes: 19 additions & 2 deletions pr_split/planner/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
class PlanMetrics:
total_groups: int
max_group_loc: int
loc_underflow: int
loc_overflow: int
dependency_edges: int
dag_depth: int
dag_width: int
file_scatter: int
undersized_groups: int
tiny_groups: int
objective: int

Expand All @@ -27,10 +29,15 @@ def _dag_depth(dag: PlanDAG) -> int:
return max(depths.values(), default=0)


def score_plan(groups: list[Group], max_loc: int) -> PlanMetrics:
def score_plan(groups: list[Group], max_loc: int, min_loc: int | None = None) -> PlanMetrics:
dag = PlanDAG(groups)

max_group_loc = max((group.estimated_loc for group in groups), default=0)
loc_underflow = (
sum(max(0, min_loc - group.estimated_loc) for group in groups)
if min_loc is not None
else 0
)
loc_overflow = sum(max(0, group.estimated_loc - max_loc) for group in groups)
dependency_edges = sum(len(group.depends_on) for group in groups)
dag_width = max((len(batch) for batch in dag.iter_ready()), default=0)
Expand All @@ -42,11 +49,19 @@ def score_plan(groups: list[Group], max_loc: int) -> PlanMetrics:
file_groups.setdefault(assignment.file_path, set()).add(group.id)
file_scatter = sum(max(0, len(group_ids) - 1) for group_ids in file_groups.values())

undersized_groups = (
# Includes empty groups (estimated_loc == 0), unlike tiny_groups below.
sum(1 for group in groups if group.estimated_loc < min_loc)
if min_loc is not None
else 0
)
tiny_threshold = max(1, max_loc // 4)
tiny_groups = sum(1 for group in groups if 0 < group.estimated_loc < tiny_threshold)

objective = (
loc_overflow * 1000
loc_underflow * 1000
+ undersized_groups * 100
+ loc_overflow * 1000
+ dependency_edges * 20
+ file_scatter * 50
+ tiny_groups * 10
Expand All @@ -58,11 +73,13 @@ def score_plan(groups: list[Group], max_loc: int) -> PlanMetrics:
return PlanMetrics(
total_groups=len(groups),
max_group_loc=max_group_loc,
loc_underflow=loc_underflow,
loc_overflow=loc_overflow,
dependency_edges=dependency_edges,
dag_depth=dag_depth,
dag_width=dag_width,
file_scatter=file_scatter,
undersized_groups=undersized_groups,
tiny_groups=tiny_groups,
objective=objective,
)
Loading
Loading