diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 5f72d91..23d4860 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,6 +1,6 @@ plugin_base: &plugin_base service-account-name: monorepo-ci - image: gcr.io/embark-shared/ml/ci-runner@sha256:54904440250d9ae14f6ddf6d72c577f2e06c85f79aa6ebe31558c35cbb93280f + image: gcr.io/embark-shared/ml/ci-runner@sha256:dac3595ade7e3e92ed006f6c29f461b71bb3a6b0ade8d3afb88ba8e55b9601d6 default-secret-name: buildkite-k8s-plugin always-pull: false use-agent-node-affinity: true @@ -30,7 +30,7 @@ small: &small resources-limit-memory: 20Gi env: - PDM_COMMAND: pdm27 + PDM_COMMAND: pdm210 PYTHON_VERSION: '3.9' steps: @@ -69,7 +69,9 @@ steps: - "pdm25" - "pdm26" - "pdm27" - - "pdm28" + - "pdm29" + - "pdm210" + command: bash .buildkite/run-pytest.sh {{matrix}} << : *small diff --git a/pdm-plugin-torch/pdm_plugin_torch/config.py b/pdm-plugin-torch/pdm_plugin_torch/config.py index 2b86046..a5c00f3 100644 --- a/pdm-plugin-torch/pdm_plugin_torch/config.py +++ b/pdm-plugin-torch/pdm_plugin_torch/config.py @@ -40,6 +40,6 @@ def variants(self): ) if self.enable_cpu: - resolves["cpu"] = ("https://download.pytorch.org/whl/cpu", "") + resolves["cpu"] = ("https://download.pytorch.org/whl/cpu", "+cpu") return resolves diff --git a/pdm-plugin-torch/pdm_plugin_torch/main.py b/pdm-plugin-torch/pdm_plugin_torch/main.py index 9773156..fd5326c 100644 --- a/pdm-plugin-torch/pdm_plugin_torch/main.py +++ b/pdm-plugin-torch/pdm_plugin_torch/main.py @@ -6,7 +6,7 @@ import tomlkit -from pdm import termui +from pdm import __version__, termui from pdm._types import RepositoryConfig from pdm.cli.commands.base import BaseCommand from pdm.cli.utils import fetch_hashes, format_lockfile, format_resolution_impossible @@ -14,7 +14,7 @@ from pdm.models.candidates import Candidate from pdm.models.repositories import BaseRepository, LockedRepository from pdm.models.requirements import Requirement, parse_requirement -from pdm.models.specifiers import get_specifier +from pdm.models.specifiers import PySpecSet, get_specifier from pdm.project import Project from pdm.resolver import resolve from pdm.resolver.providers import BaseProvider @@ -26,6 +26,11 @@ from pdm_plugin_torch.config import Configuration +is_pdm210 = PySpecSet(">=2.10").contains(__version__.__version__) +is_pdm29 = PySpecSet(">=2.9").contains(__version__.__version__) +is_pdm28 = PySpecSet(">=2.8").contains(__version__.__version__) + + def sources(project: Project, sources: list) -> list[RepositoryConfig]: result: dict[str, RepositoryConfig] = {} for source in project.pyproject.settings.get("source", []): @@ -180,7 +185,26 @@ def do_lock( ui.echo(format_resolution_impossible(err), err=True) raise ResolutionImpossible("Unable to find a resolution") from None else: - data = format_lockfile(project, mapping, dependencies) + if is_pdm210: + from pdm.project.lockfile import FLAG_STATIC_URLS + + data = format_lockfile( + project, + mapping, + dependencies, + groups=[], + strategy={FLAG_STATIC_URLS}, + ) + + elif is_pdm29: + data = format_lockfile(project, mapping, dependencies, static_urls=True) + + elif is_pdm28: + data = format_lockfile(project, mapping, dependencies, static_urls=True) + + else: + data = format_lockfile(project, mapping, dependencies) + ui.echo(f"{termui.Emoji.LOCK} Lock successful") return data @@ -249,8 +273,9 @@ def do_sync( dry_run=False, no_editable=True, install_self=False, - reinstall=True, + reinstall=False, only_keep=False, + fail_fast=True, ) with project.core.ui.logging("install"): @@ -343,7 +368,30 @@ def handle(self, project: Project, options: dict): lockfile = read_lockfile(project, plugin_config.lockfile) spec_for_version = lockfile[options.api] + (source, local_version) = resolves[options.api] + + if is_pdm210: + from pdm.project.lockfile import FLAG_STATIC_URLS + + class OverrideLockfile: + def __init__(self, lockfile): + self._lockfile = lockfile + + @property + def strategy(self): + strategies = self._lockfile.strategy + strategies.add(FLAG_STATIC_URLS) + + return strategies + + def __getattr__(self, name): + return getattr(self._lockfile, name) + + original_lockfile = project.lockfile + + project._lockfile = OverrideLockfile(original_lockfile) + reqs = [ parse_requirement(f"{req}{local_version}", False) for req in plugin_config.dependencies @@ -363,6 +411,9 @@ def handle(self, project: Project, options: dict): lockfile=spec_for_version, ) + if is_pdm210: + project._lockfile = original_lockfile + class LockCommand(BaseCommand): name = "lock" @@ -439,4 +490,8 @@ def handle(self, project: Project, options) -> None: def torch_plugin(core: Core): + if is_pdm28 and not is_pdm29: + raise RuntimeError( + "pdm 2.8.* is not supported due to not https://github.com/pdm-project/pdm/issues/2151" + ) core.register_command(TorchCommand, "torch")