From 5171404a8c6aa61dbdd10d1241f0247f0b7ec4e0 Mon Sep 17 00:00:00 2001 From: chesterxgchen Date: Thu, 9 May 2024 08:00:02 -0700 Subject: [PATCH] fix global steps --- .../step-by-step/cifar10/code/fl/train_with_mlflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py b/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py index 898d5aa507..63f18fad3d 100644 --- a/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py +++ b/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py @@ -139,7 +139,7 @@ def evaluate(input_weights): running_loss += loss.item() if i % 2000 == 1999: # print every 2000 mini-batches print(f"({client_id}) [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") - global_step = input_model.current_round * local_epochs * batch_size + epoch * batch_size + i + global_step = input_model.current_round * steps + epoch * len(trainloader) + i mlflow.log_metric("loss", running_loss / 2000, global_step) running_loss = 0.0