diff --git a/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py b/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py index 898d5aa507..63f18fad3d 100644 --- a/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py +++ b/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py @@ -139,7 +139,7 @@ def evaluate(input_weights): running_loss += loss.item() if i % 2000 == 1999: # print every 2000 mini-batches print(f"({client_id}) [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") - global_step = input_model.current_round * local_epochs * batch_size + epoch * batch_size + i + global_step = input_model.current_round * steps + epoch * len(trainloader) + i mlflow.log_metric("loss", running_loss / 2000, global_step) running_loss = 0.0