Skip to content

Commit

Permalink
🧪 Test on CUDA device
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Nov 24, 2024
1 parent b44446d commit 2f325e4
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
11 changes: 7 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,20 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.10"]
torch-version: ["1.12", "2.0"]
version:
- {python: "3.9", torch: "1.13"}
- {python: "3.10", torch: "2.0"}
- {python: "3.11", torch: "2.2"}
- {python: "3.12", torch: "2.4"}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: ${{ matrix.version.python }}
- name: Install dependencies
run: |
pip install pytest
pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
pip install torch==${{ matrix.version.torch }} --extra-index-url https://download.pytorch.org/whl/cpu
pip install .
- name: Run tests
run: pytest tests
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ keywords = [
"deep learning",
]
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"

[project.optional-dependencies]
dev = [
Expand Down
31 changes: 29 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,36 @@
import torch


def pytest_addoption(parser):
parser.addoption("--device", type=str, default="cpu")
parser.addoption("--dtype", type=str, default="float64")


@pytest.fixture(autouse=True, scope="module")
def torch_device(pytestconfig):
device = pytestconfig.getoption("device")

if device == "cpu":
yield
else:
try:
yield torch.set_default_device(device)
finally:
torch.set_default_device("cpu")


@pytest.fixture(autouse=True, scope="module")
def torch_float64():
def torch_dtype(pytestconfig):
dtype = pytestconfig.getoption("dtype")

if dtype == "float32":
dtype = torch.float32
elif dtype == "float64":
dtype = torch.float64
else:
raise NotImplementedError()

try:
yield torch.set_default_dtype(torch.float64)
yield torch.set_default_dtype(dtype)
finally:
torch.set_default_dtype(torch.float32)

0 comments on commit 2f325e4

Please sign in to comment.