diff --git a/agents/tests/unit/test_tools.py b/agents/tests/unit/test_tools.py index 0b028ce6..f8c245c4 100644 --- a/agents/tests/unit/test_tools.py +++ b/agents/tests/unit/test_tools.py @@ -20,6 +20,7 @@ GitPatchCreationToolInput, GitLogSearchTool, GitLogSearchToolInput, + discover_patch_p, ) from tools.commands import RunShellCommandTool, RunShellCommandToolInput from tools.specfile import ( @@ -533,3 +534,36 @@ async def test_git_log_search_tool_found(git_repo, cve_id, jira_issue, expected) ).middleware(GlobalTrajectoryMiddleware(pretty=True)) result = output.result assert expected in result + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "patch_content, expected_n", + [ + ( + "diff --git a/file.txt b/file.txt\n" + "index cb752151e..ceb5c5dca 100644\n" + "--- a/file.txt\n" + "+++ b/file.txt\n" + "@@ -1,2 +1,3 @@\n" + " Line 1\n" + " Line 2\n" + "+Line 3\n", + 1), + ( + "diff --git a/z/file.txt b/z/file.txt\n" + "index cb752151e..ceb5c5dca 100644\n" + "--- a/z/file.txt\n" + "+++ b/z/file.txt\n" + "@@ -1,2 +1,3 @@\n" + " Line 1\n" + " Line 2\n" + "+Line 3\n", + 2), + ] +) +async def test_discover_patch_p(git_repo, tmp_path, patch_content, expected_n): + patch_file = tmp_path / f"{expected_n}.patch" + patch_file.write_text(patch_content) + n = await discover_patch_p(patch_file, git_repo) + assert n == expected_n diff --git a/agents/tools/wicked_git.py b/agents/tools/wicked_git.py index 30f2dded..b282a13c 100644 --- a/agents/tools/wicked_git.py +++ b/agents/tools/wicked_git.py @@ -106,6 +106,42 @@ async def git_am_show_current_patch(repository_path: AbsolutePath) -> str: return "" +async def discover_patch_p(patch_file_path: AbsolutePath, repository_path: AbsolutePath) -> int: + """ + Process the given patch file and figure out with which `-p` value the patch should be applied + in the given repository. + + Using `git apply --stat` we parse the given patch and try to fit it into the given repository. + """ + cmd = ["git", "apply", "--stat", str(patch_file_path)] + exit_code, stdout, stderr = await run_subprocess(cmd, cwd=repository_path) + if exit_code != 0: + # this means the patch is borked + raise ToolError(f"Command git-apply --stat failed: {stderr}") + # expat/lib/xmlparse.c | 8 - + # .github/workflows/scripts/mass-cppcheck.sh | 1 + # .github/workflows/data/exported-symbols.txt | 2 + # expat/lib/expat.h | 15 + + lines = stdout.splitlines() + files = [line.split("|")[0].strip() for line in lines if "|" in line] + + # git-apply hates -p0: + # "git diff header lacks filename information when removing 1 leading pathname component (line 5)" + # but /usr/bin/patch should be able to handle -p0, so this is a TODO + # Nikola checked Fedora spec files: 17 -p3, 10 -p4, 2 -p5 + for n in range(1, 6): + split_this_many = n - 1 + for fi in files: + stripped_fi = fi + if split_this_many > 0: + stripped_fi = fi.split("/", split_this_many)[-1] + if (repository_path / stripped_fi).exists(): + # I know this is naive, but we certainly cannot check all files + # because some may be missing in the checkout + return n + raise ToolError(f"Failed to discover the value for `-p` for patch file: {patch_file_path}") + + class GitPatchCreationToolInput(BaseModel): repository_path: AbsolutePath = Field(description="Absolute path to the git repository") patch_file_path: AbsolutePath = Field(description="Absolute path where the patch file should be saved") @@ -131,8 +167,9 @@ async def _run( self, tool_input: GitPatchApplyToolInput, options: ToolRunOptions | None, context: RunContext ) -> StringToolOutput: ensure_git_repository(tool_input.repository_path) + p = await discover_patch_p(tool_input.patch_file_path, tool_input.repository_path) try: - cmd = ["git", "am", "--reject", str(tool_input.patch_file_path)] + cmd = ["git", "am", "--reject", f"-p{p}", str(tool_input.patch_file_path)] exit_code, stdout, stderr = await run_subprocess(cmd, cwd=tool_input.repository_path) if exit_code != 0: return StringToolOutput(