Skip to content

Commit

Permalink
Merge branch 'main' into enable_extra_information_in_params_converter
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster authored Dec 20, 2023
2 parents 730e1ef + 79b253a commit a72375b
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 5 deletions.
9 changes: 8 additions & 1 deletion examples/hello-world/hello-cyclic/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ You can follow the [hello_world notebook](../hello_world.ipynb) or the following

### 1. Install NVIDIA FLARE

Follow the [Installation](https://nvflare.readthedocs.io/en/main/quickstart.html) instructions.
Follow the [Installation](https://nvflare.readthedocs.io/en/main/quickstart.html) instructions to install NVFlare.

Install additional requirements:

```
Expand All @@ -18,6 +19,12 @@ pip3 install tensorflow

### 2. Run the experiment

Prepare the data first:

```
bash ./prepare_data.sh
```

Use nvflare simulator to run the hello-examples:

```
Expand Down
1 change: 1 addition & 0 deletions examples/hello-world/hello-cyclic/prepare_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python3 -c "from tensorflow.keras.datasets import mnist; mnist_data = mnist.load_data()"
9 changes: 8 additions & 1 deletion examples/hello-world/hello-pt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ You can follow the [hello_world notebook](../hello_world.ipynb) or the following

### 1. Install NVIDIA FLARE

Follow the [Installation](https://nvflare.readthedocs.io/en/main/quickstart.html) instructions.
Follow the [Installation](https://nvflare.readthedocs.io/en/main/quickstart.html) instructions to install NVFlare.

Install additional requirements:

```
Expand All @@ -20,6 +21,12 @@ pip3 install -r requirements.txt

### 2. Run the experiment

Prepare the data first:

```
bash ./prepare_data.sh
```

Use nvflare simulator to run the hello-examples:

```
Expand Down
3 changes: 3 additions & 0 deletions examples/hello-world/hello-pt/prepare_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
DATASET_ROOT="~/data"

python3 -c "import torchvision.datasets as datasets; datasets.CIFAR10(root='${DATASET_ROOT}', train=True, download=True)"
11 changes: 9 additions & 2 deletions examples/hello-world/hello-tf2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ Example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.htm
using federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629))
and [TensorFlow](https://tensorflow.org/) as the deep learning training framework.

> **_NOTE:_** This example uses the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset and will load its data within the trainer code.
> **_NOTE:_** This example uses the [MNIST](http://yann.lecun.com/exdb/mnist/) handwritten digits dataset and will load its data within the trainer code.
You can follow the [hello_world notebook](../hello_world.ipynb) or the following:

### 1. Install NVIDIA FLARE

Follow the [Installation](https://nvflare.readthedocs.io/en/main/quickstart.html) instructions.
Follow the [Installation](https://nvflare.readthedocs.io/en/main/quickstart.html) instructions to install NVFlare.

Install additional requirements:

```
Expand All @@ -19,6 +20,12 @@ pip3 install tensorflow

### 2. Run the experiment

Prepare the data first:

```bash
bash ./prepare_data.sh
```

Use nvflare simulator to run the hello-examples:

```
Expand Down
1 change: 1 addition & 0 deletions examples/hello-world/hello-tf2/prepare_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python3 -c "from tensorflow.keras.datasets import mnist; mnist_data = mnist.load_data()"
2 changes: 1 addition & 1 deletion nvflare/app_opt/lightning/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, rank: int = 0, load_state_dict_strict: bool = True):
"""
super(FLCallback, self).__init__()
init(rank=str(rank))
self.train_with_evaluation = get_config().get(ConfigKey.TRAIN_WITH_EVAL, False)
self.train_with_evaluation = get_config().get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.TRAIN_WITH_EVAL, False)
self.current_round = None
self.metrics = None
self.total_local_epochs = 0
Expand Down

0 comments on commit a72375b

Please sign in to comment.