Skip to content

Commit

Permalink
Fix recursive chained script
Browse files Browse the repository at this point in the history
Based on astral-sh#982, need to rebase after that get merged.
  • Loading branch information
j178 committed Apr 2, 2024
1 parent 8c7abf2 commit 75b30bb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
11 changes: 8 additions & 3 deletions rye/src/cli/run.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::env::{self, join_paths, split_paths};
use std::ffi::OsString;
use std::path::PathBuf;
Expand Down Expand Up @@ -62,14 +62,15 @@ pub fn execute(cmd: Args) -> Result<(), Error> {
None => unreachable!(),
};

invoke_script(&pyproject, args, true)?;
invoke_script(&pyproject, args, true, &mut HashSet::new())?;
unreachable!();
}

fn invoke_script(
pyproject: &PyProject,
mut args: Vec<OsString>,
exec: bool,
seen_chain: &mut HashSet<OsString>,
) -> Result<ExitStatus, Error> {
let venv_bin = pyproject.venv_bin_path();
let mut env_overrides = None;
Expand Down Expand Up @@ -126,9 +127,13 @@ fn invoke_script(
if args.len() != 1 {
bail!("extra arguments to chained commands are not allowed");
}
if seen_chain.contains(&args[0]) {
bail!("found recursive chain script");
}
seen_chain.insert(args[0].clone());
for args in commands {
let status =
invoke_script(pyproject, args.into_iter().map(Into::into).collect(), false)?;
invoke_script(pyproject, args.into_iter().map(Into::into).collect(), false, seen_chain)?;
if !status.success() {
if !exec {
return Ok(status);
Expand Down
11 changes: 9 additions & 2 deletions rye/tests/test_run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ fn test_script_chain() {
// A nested `chain` script
scripts["script_6"]["chain"] =
value(Array::from_iter(["script_1", "script_4", "script_5"]));
// NEED FIX: A recursive `chain` script
// A recursive `chain` script
scripts["script_7"]["chain"] = value(Array::from_iter(["script_7"]));

doc["tool"]["rye"]["scripts"] = scripts;
Expand Down Expand Up @@ -249,7 +249,14 @@ fn test_script_chain() {
----- stderr -----
error: script failed with exit code: 1
"###);
// rye_cmd_snapshot!(space.rye_cmd().arg("run").arg("script_7"), @r###""###);
rye_cmd_snapshot!(space.rye_cmd().arg("run").arg("script_7"), @r###"
success: false
exit_code: 1
----- stdout -----
----- stderr -----
error: found recursive chain script
"###);
}

#[test]
Expand Down

0 comments on commit 75b30bb

Please sign in to comment.