diff --git a/examples/hello-world/hello-cyclic/README.md b/examples/hello-world/hello-cyclic/README.md index c5fde50057..84d81eda1f 100644 --- a/examples/hello-world/hello-cyclic/README.md +++ b/examples/hello-world/hello-cyclic/README.md @@ -9,7 +9,8 @@ You can follow the [hello_world notebook](../hello_world.ipynb) or the following ### 1. Install NVIDIA FLARE -Follow the [Installation](https://nvflare.readthedocs.io/en/main/quickstart.html) instructions. +Follow the [Installation](https://nvflare.readthedocs.io/en/main/quickstart.html) instructions to install NVFlare. + Install additional requirements: ``` @@ -18,6 +19,12 @@ pip3 install tensorflow ### 2. Run the experiment +Prepare the data first: + +``` +bash ./prepare_data.sh +``` + Use nvflare simulator to run the hello-examples: ``` diff --git a/examples/hello-world/hello-cyclic/prepare_data.sh b/examples/hello-world/hello-cyclic/prepare_data.sh new file mode 100755 index 0000000000..a9bd5c7741 --- /dev/null +++ b/examples/hello-world/hello-cyclic/prepare_data.sh @@ -0,0 +1 @@ +python3 -c "from tensorflow.keras.datasets import mnist; mnist_data = mnist.load_data()" \ No newline at end of file diff --git a/examples/hello-world/hello-pt/README.md b/examples/hello-world/hello-pt/README.md index 2bb0c369ea..adabac1e53 100644 --- a/examples/hello-world/hello-pt/README.md +++ b/examples/hello-world/hello-pt/README.md @@ -11,7 +11,8 @@ You can follow the [hello_world notebook](../hello_world.ipynb) or the following ### 1. Install NVIDIA FLARE -Follow the [Installation](https://nvflare.readthedocs.io/en/main/quickstart.html) instructions. +Follow the [Installation](https://nvflare.readthedocs.io/en/main/quickstart.html) instructions to install NVFlare. + Install additional requirements: ``` @@ -20,6 +21,12 @@ pip3 install -r requirements.txt ### 2. Run the experiment +Prepare the data first: + +``` +bash ./prepare_data.sh +``` + Use nvflare simulator to run the hello-examples: ``` diff --git a/examples/hello-world/hello-pt/prepare_data.sh b/examples/hello-world/hello-pt/prepare_data.sh new file mode 100755 index 0000000000..b9bad39f45 --- /dev/null +++ b/examples/hello-world/hello-pt/prepare_data.sh @@ -0,0 +1,3 @@ +DATASET_ROOT="~/data" + +python3 -c "import torchvision.datasets as datasets; datasets.CIFAR10(root='${DATASET_ROOT}', train=True, download=True)" diff --git a/examples/hello-world/hello-tf2/README.md b/examples/hello-world/hello-tf2/README.md index aed48d5866..7a959c2a9e 100644 --- a/examples/hello-world/hello-tf2/README.md +++ b/examples/hello-world/hello-tf2/README.md @@ -4,13 +4,14 @@ Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.htm using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)) and [TensorFlow](https://tensorflow.org/) as the deep learning training framework. -> **_NOTE:_** This example uses the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the trainer code. +> **_NOTE:_** This example uses the [MNIST](http://yann.lecun.com/exdb/mnist/) handwritten digits dataset and will load its data within the trainer code. You can follow the [hello_world notebook](../hello_world.ipynb) or the following: ### 1. Install NVIDIA FLARE -Follow the [Installation](https://nvflare.readthedocs.io/en/main/quickstart.html) instructions. +Follow the [Installation](https://nvflare.readthedocs.io/en/main/quickstart.html) instructions to install NVFlare. + Install additional requirements: ``` @@ -19,6 +20,12 @@ pip3 install tensorflow ### 2. Run the experiment +Prepare the data first: + +```bash +bash ./prepare_data.sh +``` + Use nvflare simulator to run the hello-examples: ``` diff --git a/examples/hello-world/hello-tf2/prepare_data.sh b/examples/hello-world/hello-tf2/prepare_data.sh new file mode 100755 index 0000000000..a9bd5c7741 --- /dev/null +++ b/examples/hello-world/hello-tf2/prepare_data.sh @@ -0,0 +1 @@ +python3 -c "from tensorflow.keras.datasets import mnist; mnist_data = mnist.load_data()" \ No newline at end of file diff --git a/nvflare/app_opt/lightning/api.py b/nvflare/app_opt/lightning/api.py index e692d62acf..d34025e913 100644 --- a/nvflare/app_opt/lightning/api.py +++ b/nvflare/app_opt/lightning/api.py @@ -72,7 +72,7 @@ def __init__(self, rank: int = 0, load_state_dict_strict: bool = True): """ super(FLCallback, self).__init__() init(rank=str(rank)) - self.train_with_evaluation = get_config().get(ConfigKey.TRAIN_WITH_EVAL, False) + self.train_with_evaluation = get_config().get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.TRAIN_WITH_EVAL, False) self.current_round = None self.metrics = None self.total_local_epochs = 0