From 3c8cf3469d50008a50c2711c20d3cc4d0e27e8a2 Mon Sep 17 00:00:00 2001
From: Sebastian Hoffmann <shoffmann.git@gmail.com>
Date: Mon, 18 Mar 2024 16:54:40 +0100
Subject: [PATCH] fix: metric reduction with empty values

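Reducing a metric for which no values were tracked was previously
unhandled. reduce_locally() and reduce_globally() now return None in
that case. For global reductions, workers first agree on whether any
values were tracked; if only some workers are empty while others are
not, a ValueError is raised, as this indicates inconsistent tracking
across workers. reduce_globally() also loses its async_op argument:
the emptiness check relies on a blocking all_gather_object(), so the
reduction is always synchronous.

A minimal sketch of the new behavior (mirroring the added
test_empty_reduction; assumes an initialized process group):

    reducer = MetricReducer(reduction=Reduction.MIN, globally=True)
    assert reducer.reduce_locally() is None   # no values tracked
    assert reducer.reduce_globally() is None  # every worker is empty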
---
 dmlcloud/metrics.py  | 21 ++++++++++++++++++---
 dmlcloud/pipeline.py |  4 ++--
 test/test_metrics.py |  8 ++++++++
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/dmlcloud/metrics.py b/dmlcloud/metrics.py
index 09ff494..3a2644a 100644
--- a/dmlcloud/metrics.py
+++ b/dmlcloud/metrics.py
@@ -105,6 +105,9 @@ def reduce_and_append(self, value):
         self.values.append(value)
 
     def reduce_locally(self):
+        if len(self.values) == 0:
+            return None
+
         if isinstance(self.dim, list):
             dim = [0] + [d + 1 for d in self.dim]
         elif isinstance(self.dim, int):
@@ -115,14 +118,26 @@ def reduce_locally(self):
         tensor = reduce_tensor(tensor, reduction=self.reduction, dim=dim)
         return tensor
 
-    def reduce_globally(self, group=None, async_op=False):
+    def reduce_globally(self, group=None):
+        # If no values were tracked, the result is None; all workers must agree on emptiness.
+        if self.globally:
+            empty_workers = [None] * dist.get_world_size(group)
+            dist.all_gather_object(empty_workers, len(self.values) == 0, group=group)
+            if any(empty_workers):
+                if len(empty_workers) > 1 and not all(empty_workers):
+                    raise ValueError('Some workers tracked values this epoch and some did not. This is likely a bug.')
+                else:
+                    return None
+        elif len(self.values) == 0:
+            return None
+
         tensor = self.reduce_locally()
         if self.globally:
             if self.reduction == Reduction.MEAN:
-                dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_op)
+                dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
                 tensor /= dist.get_world_size(group)
             else:
-                dist.all_reduce(tensor, op=self.reduction.as_torch(), group=group, async_op=async_op)
+                dist.all_reduce(tensor, op=self.reduction.as_torch(), group=group)
         return tensor
 
     def state_dict(self):
diff --git a/dmlcloud/pipeline.py b/dmlcloud/pipeline.py
index 7a7ac82..e4a5609 100644
--- a/dmlcloud/pipeline.py
+++ b/dmlcloud/pipeline.py
@@ -94,8 +94,8 @@ def register_dataset(self, name: str, dataset: Union[DataLoader, Dataset, Sequen
                 msg += f'  - Batches (Total): ~{length * dist.get_world_size()}\n'
                 msg += f'  - Batches (/Worker): {length}\n'
             except TypeError:  # __len__ not implemented
-                msg += f'  - Batches (Total): N/A\n'
-                msg += f'  - Batches (/Worker): N/A\n'
+                msg += '  - Batches (Total): N/A\n'
+                msg += '  - Batches (/Worker): N/A\n'
             self.logger.info(msg)
 
     def append_stage(self, stage: Stage, max_epochs: Optional[int] = None, name: Optional[str] = None):
diff --git a/test/test_metrics.py b/test/test_metrics.py
index 1deca76..8787517 100644
--- a/test/test_metrics.py
+++ b/test/test_metrics.py
@@ -79,6 +79,14 @@ def test_serialization(self):
         assert new_reducer.dim == [1, 2, 3]
         assert new_reducer.values == reducer.values
 
+    def test_empty_reduction(self, torch_distributed):
+        reducer = MetricReducer(reduction=Reduction.MIN, globally=True)
+        result = reducer.reduce_locally()
+        assert result is None
+
+        result = reducer.reduce_globally()
+        assert result is None
+
 
 class TestMetricTracker:
     def test_dictionary(self):