diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f425c29..3de321bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +--- + +## [UnReleased] - 2025-08-DD + +### Changed + +- CLI: replace dependencies also in `pyproject.toml` ([#423](https://github.com/Lightning-AI/utilities/pull/423)) + + --- ## [0.15.1] - 2025-08-04 diff --git a/src/lightning_utilities/cli/dependencies.py b/src/lightning_utilities/cli/dependencies.py index 34d7ec17..04c770e5 100644 --- a/src/lightning_utilities/cli/dependencies.py +++ b/src/lightning_utilities/cli/dependencies.py @@ -45,19 +45,19 @@ def _prune_packages(req_file: str, packages: Sequence[str]) -> None: fp.writelines(lines) -def _replace_min_txt(fname: str) -> None: - with open(fname) as fopen: +def _replace_min_req_in_txt(req_file: str) -> None: + with open(req_file) as fopen: req = fopen.read().replace(">=", "==") - with open(fname, "w") as fw: + with open(req_file, "w") as fw: fw.write(req) -def _replace_min_pyproject_toml(fname: str) -> None: +def _replace_min_req_in_pyproject_toml(proj_file: str = "pyproject.toml") -> None: """Replace all `>=` with `==` in the standard pyproject.toml file in [project.dependencies].""" import tomlkit # Load and parse the existing pyproject.toml - with open(fname, encoding="utf-8") as f: + with open(proj_file, encoding="utf-8") as f: content = f.read() doc = tomlkit.parse(content) @@ -72,7 +72,7 @@ def _replace_min_pyproject_toml(fname: str) -> None: deps[i] = req.replace(">=", "==") # Dump back out, preserving layout - with open(fname, "w", encoding="utf-8") as f: + with open(proj_file, "w", encoding="utf-8") as f: f.write(tomlkit.dumps(doc)) @@ -82,28 +82,53 @@ def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FI req_files = [req_files] for fname in req_files: if fname.endswith(".txt"): - _replace_min_txt(fname) + _replace_min_req_in_txt(fname) elif os.path.basename(fname) == "pyproject.toml": - _replace_min_pyproject_toml(fname) + _replace_min_req_in_pyproject_toml(fname) else: warnings.warn( "Only *.txt with plain list of requirements or standard pyproject.toml are supported." - f" File '{fname}' is not supported.", + f"Provided '{fname}' is not supported.", UserWarning, stacklevel=2, ) -def _replace_package_name(requirements: list[str], old_package: str, new_package: str) -> list[str]: - """Replace one package by another with the same version in a given requirement file. - - >>> _replace_package_name(["torch>=1.0 # comment", "torchvision>=0.2", "torchtext <0.3"], "torch", "pytorch") - ['pytorch>=1.0 # comment', 'torchvision>=0.2', 'torchtext <0.3'] - - """ +def _replace_package_name_in_txt(req_file: str, old_package: str, new_package: str) -> None: + """Replace one package by another with the same version in a given requirement file.""" + # load file + with open(req_file) as fopen: + requirements = fopen.readlines() + # replace all occurrences for i, req in enumerate(requirements): requirements[i] = re.sub(r"^" + re.escape(old_package) + r"(?=[ <=>#]|$)", new_package, req) - return requirements + # save file + with open(req_file, "w") as fw: + fw.writelines(requirements) + + +def _replace_package_name_in_pyproject_toml(proj_file: str, old_package: str, new_package: str) -> None: + """Replace one package by another with the same version in the standard pyproject.toml file.""" + import tomlkit + + # Load and parse the existing pyproject.toml + with open(proj_file, encoding="utf-8") as f: + content = f.read() + doc = tomlkit.parse(content) + + # todo: consider also replace extras in [dependency-groups] -> extras = [...] + deps = doc.get("project", {}).get("dependencies") + if not deps: + return + + # Replace '>=version' with '==version' in each dependency + for i, req in enumerate(deps): + # Simple string value + deps[i] = re.sub(r"^" + re.escape(old_package) + r"(?=[ <=>]|$)", new_package, req) + + # Dump back out, preserving layout + with open(proj_file, "w", encoding="utf-8") as f: + f.write(tomlkit.dumps(doc)) def replace_package_in_requirements( @@ -113,8 +138,14 @@ def replace_package_in_requirements( if isinstance(req_files, str): req_files = [req_files] for fname in req_files: - with open(fname) as fopen: - reqs = fopen.readlines() - reqs = _replace_package_name(reqs, old_package, new_package) - with open(fname, "w") as fw: - fw.writelines(reqs) + if fname.endswith(".txt"): + _replace_package_name_in_txt(fname, old_package, new_package) + elif os.path.basename(fname) == "pyproject.toml": + _replace_package_name_in_pyproject_toml(fname, old_package, new_package) + else: + warnings.warn( + "Only *.txt with plain list of requirements or standard pyproject.toml are supported." + f"Provided '{fname}' is not supported.", + UserWarning, + stacklevel=2, + ) diff --git a/tests/unittests/cli/test_dependencies.py b/tests/unittests/cli/test_dependencies.py index 5185a6f4..008b6669 100644 --- a/tests/unittests/cli/test_dependencies.py +++ b/tests/unittests/cli/test_dependencies.py @@ -49,7 +49,7 @@ def test_oldest_packages_pyproject_toml(tmpdir): ' "abc>=0.1",\n', "]\n", ]) - replace_oldest_version(req_files=[str(req_file)]) + replace_oldest_version(req_files=str(req_file)) with open(req_file) as fp: lines = fp.readlines() assert lines == [ @@ -59,3 +59,25 @@ def test_oldest_packages_pyproject_toml(tmpdir): ' "abc==0.1",\n', "]\n", ] + + +def test_replace_packages_pyproject_toml(tmpdir): + req_file = tmpdir / "pyproject.toml" + with open(req_file, "w") as fp: + fp.writelines([ + "[project]\n", + "dependencies = [\n", + ' "fire>0.2",\n', + ' "abc>=0.1",\n', + "]\n", + ]) + replace_package_in_requirements(req_files=str(req_file), old_package="fire", new_package="water") + with open(req_file) as fp: + lines = fp.readlines() + assert lines == [ + "[project]\n", + "dependencies = [\n", + ' "water>0.2",\n', + ' "abc>=0.1",\n', + "]\n", + ]