diff --git a/rye/src/cli/run.rs b/rye/src/cli/run.rs index 394c025827..eab545f21e 100644 --- a/rye/src/cli/run.rs +++ b/rye/src/cli/run.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::env::{self, join_paths, split_paths}; use std::ffi::OsString; use std::path::PathBuf; @@ -62,7 +62,7 @@ pub fn execute(cmd: Args) -> Result<(), Error> { None => unreachable!(), }; - invoke_script(&pyproject, args, true)?; + invoke_script(&pyproject, args, true, &mut HashSet::new())?; unreachable!(); } @@ -70,6 +70,7 @@ fn invoke_script( pyproject: &PyProject, mut args: Vec, exec: bool, + seen_chain: &mut HashSet, ) -> Result { let venv_bin = pyproject.venv_bin_path(); let mut env_overrides = None; @@ -126,9 +127,13 @@ fn invoke_script( if args.len() != 1 { bail!("extra arguments to chained commands are not allowed"); } + if seen_chain.contains(&args[0]) { + bail!("found recursive chain script"); + } + seen_chain.insert(args[0].clone()); for args in commands { let status = - invoke_script(pyproject, args.into_iter().map(Into::into).collect(), false)?; + invoke_script(pyproject, args.into_iter().map(Into::into).collect(), false, seen_chain)?; if !status.success() { if !exec { return Ok(status); diff --git a/rye/tests/test_run.rs b/rye/tests/test_run.rs index 7fd5fabc38..1362a53a6d 100644 --- a/rye/tests/test_run.rs +++ b/rye/tests/test_run.rs @@ -209,7 +209,7 @@ fn test_script_chain() { // A nested `chain` script scripts["script_6"]["chain"] = value(Array::from_iter(["script_1", "script_4", "script_5"])); - // NEED FIX: A recursive `chain` script + // A recursive `chain` script scripts["script_7"]["chain"] = value(Array::from_iter(["script_7"])); doc["tool"]["rye"]["scripts"] = scripts; @@ -249,7 +249,14 @@ fn test_script_chain() { ----- stderr ----- error: script failed with exit code: 1 "###); - // rye_cmd_snapshot!(space.rye_cmd().arg("run").arg("script_7"), @r###""###); + rye_cmd_snapshot!(space.rye_cmd().arg("run").arg("script_7"), @r###" + success: false + exit_code: 1 + ----- stdout ----- + + ----- stderr ----- + error: found recursive chain script + "###); } #[test]