Skip to content
11 changes: 7 additions & 4 deletions tests/tests_pytorch/utilities/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,16 @@ def test_empty_model_size(max_depth):


@pytest.mark.parametrize(
"accelerator",
("accelerator", "precision"),
[
pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)),
pytest.param("mps", marks=RunIf(mps=True)),
pytest.param("gpu", "16-true", marks=RunIf(min_cuda_gpus=1)),
pytest.param("gpu", "32-true", marks=RunIf(min_cuda_gpus=1)),
pytest.param("gpu", "64-true", marks=RunIf(min_cuda_gpus=1)),
pytest.param("mps", "16-true", marks=RunIf(mps=True)),
pytest.param("mps", "32-true", marks=RunIf(mps=True)),
# Note: "64-true" with "mps" is skipped because MPS does not support float64
],
)
@pytest.mark.parametrize("precision", ["16-true", "32-true", "64-true"])
def test_model_size_precision(tmp_path, accelerator, precision):
"""Test model size for different precision types."""
model = PreCalculatedModel(precision=int(precision.split("-")[0]))
Expand Down
Loading