diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a17f093..ae2442f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,17 +21,20 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.10"] - torch-version: ["1.12", "2.0"] + version: + - {python: "3.9", torch: "1.13"} + - {python: "3.10", torch: "2.0"} + - {python: "3.11", torch: "2.2"} + - {python: "3.12", torch: "2.4"} steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} + python-version: ${{ matrix.version.python }} - name: Install dependencies run: | pip install pytest - pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu + pip install torch==${{ matrix.version.torch }} --extra-index-url https://download.pytorch.org/whl/cpu pip install . - name: Run tests run: pytest tests diff --git a/pyproject.toml b/pyproject.toml index 657f51f..bed81ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ keywords = [ "deep learning", ] readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" [project.optional-dependencies] dev = [ diff --git a/tests/conftest.py b/tests/conftest.py index 6cccc9e..f88303e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,9 +4,36 @@ import torch +def pytest_addoption(parser): + parser.addoption("--device", type=str, default="cpu") + parser.addoption("--dtype", type=str, default="float64") + + +@pytest.fixture(autouse=True, scope="module") +def torch_device(pytestconfig): + device = pytestconfig.getoption("device") + + if device == "cpu": + yield + else: + try: + yield torch.set_default_device(device) + finally: + torch.set_default_device("cpu") + + @pytest.fixture(autouse=True, scope="module") -def torch_float64(): +def torch_dtype(pytestconfig): + dtype = pytestconfig.getoption("dtype") + + if dtype == "float32": + dtype = torch.float32 + elif dtype == "float64": + dtype = torch.float64 + else: + raise NotImplementedError() + try: - yield torch.set_default_dtype(torch.float64) + yield torch.set_default_dtype(dtype) finally: torch.set_default_dtype(torch.float32)