From 73d953413ea6bc776b91f9cc43f8eb52ca461b20 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Sun, 26 Oct 2025 13:13:45 +0530 Subject: [PATCH 1/6] Refactor DDP test to use multiprocessing context Refactor DDP test to use spawn context for process management. --- tests/unittests/image/test_ssim.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index e327e7d7f70..35a0e7bd199 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -394,4 +394,14 @@ def test_ssim_reduction_none_ddp(): free_port = find_free_port() if free_port == -1: pytest.skip("No free port available for DDP test.") - mp.spawn(_run_ssim_ddp, args=(world_size, free_port), nprocs=world_size, join=True) + # Use spawn context to avoid module reimport issues + ctx = mp.get_context('spawn') + processes = [] + for rank in range(world_size): + p = ctx.Process(target=_run_ssim_ddp, args=(rank, world_size, free_port)) + p.start() + processes.append(p) + + for p in processes: + p.join() + assert p.exitcode == 0, f"Process failed with exit code {p.exitcode}" From a3ba71ddb2813b5fdc1033c91c2f945f7375ee04 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 26 Oct 2025 07:47:01 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/image/test_ssim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 35a0e7bd199..5c793db8762 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -395,13 +395,13 @@ def test_ssim_reduction_none_ddp(): if free_port == -1: pytest.skip("No free port available for DDP test.") # Use spawn context to avoid module reimport issues - ctx = mp.get_context('spawn') + ctx = mp.get_context("spawn") processes = [] for rank in range(world_size): p = ctx.Process(target=_run_ssim_ddp, args=(rank, world_size, free_port)) p.start() processes.append(p) - + for p in processes: p.join() assert p.exitcode == 0, f"Process failed with exit code {p.exitcode}" From 95abf31cdba174d3d1bde2b3eadc28dc9fe6f2d6 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Sun, 26 Oct 2025 13:20:07 +0530 Subject: [PATCH 3/6] Refactor DDP test process spawning Refactor DDP test to use spawn context for process management. --- tests/unittests/image/test_ms_ssim.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/unittests/image/test_ms_ssim.py b/tests/unittests/image/test_ms_ssim.py index 89b9b5778bd..fb2430c7e42 100644 --- a/tests/unittests/image/test_ms_ssim.py +++ b/tests/unittests/image/test_ms_ssim.py @@ -139,4 +139,14 @@ def test_ms_ssim_reduction_none_ddp(): free_port = find_free_port() if free_port == -1: pytest.skip("No free port available for DDP test.") - mp.spawn(_run_ms_ssim_ddp, args=(world_size, free_port), nprocs=world_size, join=True) + # Use spawn context to avoid module reimport issues + ctx = mp.get_context('spawn') + processes = [] + for rank in range(world_size): + p = ctx.Process(target=_run_ms_ssim_ddp, args=(rank, world_size, free_port)) + p.start() + processes.append(p) + + for p in processes: + p.join() + assert p.exitcode == 0, f"Process failed with exit code {p.exitcode}" From 0847a6ce5697019fb4f03c5911b8ca76b618a357 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 26 Oct 2025 07:50:27 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/image/test_ms_ssim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/image/test_ms_ssim.py b/tests/unittests/image/test_ms_ssim.py index fb2430c7e42..3acc08800e1 100644 --- a/tests/unittests/image/test_ms_ssim.py +++ b/tests/unittests/image/test_ms_ssim.py @@ -140,13 +140,13 @@ def test_ms_ssim_reduction_none_ddp(): if free_port == -1: pytest.skip("No free port available for DDP test.") # Use spawn context to avoid module reimport issues - ctx = mp.get_context('spawn') + ctx = mp.get_context("spawn") processes = [] for rank in range(world_size): p = ctx.Process(target=_run_ms_ssim_ddp, args=(rank, world_size, free_port)) p.start() processes.append(p) - + for p in processes: p.join() assert p.exitcode == 0, f"Process failed with exit code {p.exitcode}" From 6ded0b29ff30117f78b7860205d6d55d548d5974 Mon Sep 17 00:00:00 2001 From: rittik9 Date: Sun, 26 Oct 2025 16:35:20 +0530 Subject: [PATCH 5/6] refactor --- tests/unittests/image/__init__.py | 34 +++++++++++++++++++++++++++ tests/unittests/image/test_ms_ssim.py | 19 +-------------- tests/unittests/image/test_ssim.py | 21 ++--------------- 3 files changed, 37 insertions(+), 37 deletions(-) diff --git a/tests/unittests/image/__init__.py b/tests/unittests/image/__init__.py index 8eea7d284b8..cebf9f88841 100644 --- a/tests/unittests/image/__init__.py +++ b/tests/unittests/image/__init__.py @@ -15,6 +15,8 @@ import torch import torch.distributed as dist +from torchmetrics.image import StructuralSimilarityIndexMeasure +from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure from unittests import _PATH_ALL_TESTS @@ -33,3 +35,35 @@ def cleanup_ddp(): """Clean up the DDP process group if initialized.""" if dist.is_initialized(): dist.destroy_process_group() + +def _run_ssim_ddp(rank: int, world_size: int, free_port: int): + """Run SSIM metric computation in a DDP setup.""" + try: + setup_ddp(rank, world_size, free_port) + device = torch.device(f"cuda:{rank}") + metric = StructuralSimilarityIndexMeasure(reduction="none").to(device) + + for _ in range(3): + x, y = torch.rand(4, 3, 224, 224).to(device).chunk(2) + metric.update(x, y) + + result = metric.compute() + assert isinstance(result, torch.Tensor), "Expected compute result to be a tensor" + finally: + cleanup_ddp() + +def _run_ms_ssim_ddp(rank: int, world_size: int, free_port: int): + """Run MSSSIM metric computation in a DDP setup.""" + try: + setup_ddp(rank, world_size, free_port) + device = torch.device(f"cuda:{rank}") + metric = MultiScaleStructuralSimilarityIndexMeasure(reduction="none").to(device) + + for _ in range(3): + x, y = torch.rand(4, 3, 224, 224).to(device).chunk(2) + metric.update(x, y) + + result = metric.compute() + assert isinstance(result, torch.Tensor), "Expected compute result to be a tensor" + finally: + cleanup_ddp() diff --git a/tests/unittests/image/test_ms_ssim.py b/tests/unittests/image/test_ms_ssim.py index 3acc08800e1..dc779fd9510 100644 --- a/tests/unittests/image/test_ms_ssim.py +++ b/tests/unittests/image/test_ms_ssim.py @@ -23,7 +23,7 @@ from unittests import NUM_BATCHES, _Input from unittests._helpers import _IS_WINDOWS, seed_all from unittests._helpers.testers import MetricTester -from unittests.image import cleanup_ddp, setup_ddp +from unittests.image import _run_ms_ssim_ddp from unittests.utilities.test_utilities import find_free_port seed_all(42) @@ -110,23 +110,6 @@ def test_ms_ssim_contrast_sensitivity(): assert isinstance(out, torch.Tensor) -def _run_ms_ssim_ddp(rank: int, world_size: int, free_port: int): - """Run MSSSIM metric computation in a DDP setup.""" - try: - setup_ddp(rank, world_size, free_port) - device = torch.device(f"cuda:{rank}") - metric = MultiScaleStructuralSimilarityIndexMeasure(reduction="none").to(device) - - for _ in range(3): - x, y = torch.rand(4, 3, 224, 224).to(device).chunk(2) - metric.update(x, y) - - result = metric.compute() - assert isinstance(result, torch.Tensor), "Expected compute result to be a tensor" - finally: - cleanup_ddp() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") @pytest.mark.skipif(_IS_WINDOWS, reason="DDP not supported on Windows") def test_ms_ssim_reduction_none_ddp(): diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 5c793db8762..2105053d138 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -26,7 +26,7 @@ from unittests import NUM_BATCHES, _Input from unittests._helpers import _IS_WINDOWS, seed_all from unittests._helpers.testers import MetricTester -from unittests.image import cleanup_ddp, setup_ddp +from unittests.image import _run_ssim_ddp from unittests.utilities.test_utilities import find_free_port seed_all(42) @@ -365,23 +365,6 @@ def test_ssim_for_correct_padding(): assert structural_similarity_index_measure(preds, target) < 1.0 -def _run_ssim_ddp(rank: int, world_size: int, free_port: int): - """Run SSIM metric computation in a DDP setup.""" - try: - setup_ddp(rank, world_size, free_port) - device = torch.device(f"cuda:{rank}") - metric = StructuralSimilarityIndexMeasure(reduction="none").to(device) - - for _ in range(3): - x, y = torch.rand(4, 3, 224, 224).to(device).chunk(2) - metric.update(x, y) - - result = metric.compute() - assert isinstance(result, torch.Tensor), "Expected compute result to be a tensor" - finally: - cleanup_ddp() - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") @pytest.mark.skipif(_IS_WINDOWS, reason="DDP not supported on Windows") def test_ssim_reduction_none_ddp(): @@ -404,4 +387,4 @@ def test_ssim_reduction_none_ddp(): for p in processes: p.join() - assert p.exitcode == 0, f"Process failed with exit code {p.exitcode}" + assert p.exitcode == 0, f"Process failed with exit code {p.exitcode}" \ No newline at end of file From 6f6023d297b14943cf237cf3e51c135a0d1197a4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 26 Oct 2025 11:05:47 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/image/__init__.py | 4 +++- tests/unittests/image/test_ssim.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unittests/image/__init__.py b/tests/unittests/image/__init__.py index cebf9f88841..a31775c0bde 100644 --- a/tests/unittests/image/__init__.py +++ b/tests/unittests/image/__init__.py @@ -15,9 +15,9 @@ import torch import torch.distributed as dist + from torchmetrics.image import StructuralSimilarityIndexMeasure from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure - from unittests import _PATH_ALL_TESTS _SAMPLE_IMAGE = os.path.join(_PATH_ALL_TESTS, "_data", "image", "i01_01_5.bmp") @@ -36,6 +36,7 @@ def cleanup_ddp(): if dist.is_initialized(): dist.destroy_process_group() + def _run_ssim_ddp(rank: int, world_size: int, free_port: int): """Run SSIM metric computation in a DDP setup.""" try: @@ -52,6 +53,7 @@ def _run_ssim_ddp(rank: int, world_size: int, free_port: int): finally: cleanup_ddp() + def _run_ms_ssim_ddp(rank: int, world_size: int, free_port: int): """Run MSSSIM metric computation in a DDP setup.""" try: diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 2105053d138..b1040ba911d 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -387,4 +387,4 @@ def test_ssim_reduction_none_ddp(): for p in processes: p.join() - assert p.exitcode == 0, f"Process failed with exit code {p.exitcode}" \ No newline at end of file + assert p.exitcode == 0, f"Process failed with exit code {p.exitcode}"