This repository was archived by the owner on Nov 28, 2025. It is now read-only.

Commit 598bb4d

committed
Add an example to train the resnet56 model using MultiWorkerMirroredStrategy on the CIFAR-10 dataset
1 parent 2fb34e4 commit 598bb4d

File tree

5 files changed: +670 −6 lines changed

distribution_strategy/multi_worker_mirrored_strategy/README.md

Lines changed: 115 additions & 5 deletions
@@ -4,7 +4,7 @@
The steps below are meant to train models using [MultiWorkerMirrored Strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy) with the TensorFlow 2.x API on the Kubernetes platform.

Reference programs such as [keras_mnist.py](examples/keras_mnist.py), [custom_training_mnist.py](examples/custom_training_mnist.py) and [keras_resnet_cifar.py](examples/keras_resnet_cifar.py) are available in the examples directory.

The Kubernetes manifest templates and other cluster-specific configuration are available in the [kubernetes](kubernetes) directory.
@@ -28,14 +28,39 @@ here are instructions to [create GKE clusters](https://cloud.google.com/kubernet
5. Install [Docker](https://docs.docker.com/get-docker/) for your system, and create an account that you can associate with your container images.

6. For the mnist examples, a [persistent-volume-claim](https://kubernetes.io/docs/concepts/storage/persistent-volumes/) needs to be available for model storage and checkpointing, to be mounted onto the chief worker pod. The steps below include the YAML to create a persistent-volume-claim for GKE backed by GCEPersistentDisk.

### Additional prerequisites for the resnet56 example

1. Create a [service account](https://cloud.google.com/compute/docs/access/service-accounts) and download its key file in JSON format. Assign the Storage Admin role for [Google Cloud Storage](https://cloud.google.com/storage/) to this service account:

   ```bash
   gcloud iam service-accounts create <service_account_id> --display-name="<display_name>"
   ```

   ```bash
   gcloud projects add-iam-policy-binding <project_id> \
     --member="serviceAccount:<service_account_id>@<project_id>.iam.gserviceaccount.com" \
     --role="roles/storage.admin"
   ```

2. Create a Kubernetes secret from the JSON key file of your service account:

   ```bash
   kubectl create secret generic credential --from-file=key.json=<path_to_json_file>
   ```

3. For GPU-based training, ensure your Kubernetes cluster has a node pool with GPUs enabled. The steps to achieve this on GKE are available [here](https://cloud.google.com/kubernetes-engine/docs/how-to/gpus).
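The JSON key file downloaded in step 1 can be sanity-checked before it is turned into a secret. A minimal sketch — the `check_key_file` helper is illustrative, not part of this repo; the field names follow the standard GCP service-account key format:

```python
import json

# Fields present in every GCP service-account key file (JSON format).
REQUIRED_FIELDS = {"type", "project_id", "private_key", "client_email"}

def check_key_file(path):
    """Return the service-account email if the key file looks valid."""
    with open(path) as f:
        key = json.load(f)
    missing = REQUIRED_FIELDS - key.keys()
    if missing:
        raise ValueError(f"key file is missing fields: {sorted(missing)}")
    if key["type"] != "service_account":
        raise ValueError(f"unexpected credential type: {key['type']}")
    return key["client_email"]
```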
## Steps to train the mnist examples

1. Follow the instructions for building and pushing the Docker image to a Docker registry in the [Docker README](examples/README.md).

2. Copy the template file `MultiWorkerMirroredTemplate.yaml.jinja`:

   ```sh
   cp kubernetes/MultiWorkerMirroredTemplate.yaml.jinja myjob.template.jinja
   ```
@@ -114,4 +139,89 @@ here are instructions to [create GKE clusters](https://cloud.google.com/kubernet
   ```sh
   kubectl -n <namespace> exec --stdin --tty <volume-inspector-pod> -- /bin/sh
   ```

   The contents of the trained model are available for inspection at `model_checkpoint_dir`.
## Steps to train the resnet examples

1. Follow the instructions for building and pushing the Docker image built with `Dockerfile.gpu` to a Docker registry in the [Docker README](examples/README.md).

2. Copy the template file `EnhancedMultiWorkerMirroredTemplate.yaml.jinja`:

   ```sh
   cp kubernetes/EnhancedMultiWorkerMirroredTemplate.yaml.jinja myjob.template.jinja
   ```

3. Create three buckets for model data, checkpoints and training logs using either the GCP web UI or the gsutil tool (included with the gcloud tool you installed above):

   ```bash
   gsutil mb gs://<bucket_name>
   ```

   You will use these bucket names to set `data_dir`, `log_dir` and `model_dir` in step 5.
160+
161+
162+
4. Download CIFAR-10 data and place them in your data_dir bucket. Head to the [ResNet in TensorFlow](https://github.com/tensorflow/models/tree/r1.13.0/official/resnet#cifar-10) directory to obtain CIFAR-10 data. Alternatively, you can use this [direct link](https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz) to download and extract the data yourself as well.
163+
164+
```bash
165+
python cifar10_download_and_extract.py
166+
```
167+
168+
Upload the contents of cifar-10-batches-bin directory to your `data_dir` bucket.
169+
170+
```bash
171+
gsutil -m cp cifar-10-batches-bin/* gs://<your_data_dir>/
172+
```
173+
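Before uploading, you can verify the extracted directory contains the six CIFAR-10 binary files the training job expects. A small sketch — the `missing_cifar_files` helper is illustrative; the file names are those produced by extracting cifar-10-binary.tar.gz:

```python
import os

# Files produced by extracting cifar-10-binary.tar.gz into
# cifar-10-batches-bin: five training batches plus one test batch.
EXPECTED = [f"data_batch_{i}.bin" for i in range(1, 6)] + ["test_batch.bin"]

def missing_cifar_files(data_dir):
    """Return the expected CIFAR-10 binaries not found in data_dir."""
    return [name for name in EXPECTED
            if not os.path.isfile(os.path.join(data_dir, name))]
```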
5. Edit `myjob.template.jinja` to set the job parameters:

   1. `script` - the training program to run; either `keras_resnet_cifar.py` or `your_own_training_example.py`
   2. `name` - the prefix attached to all the Kubernetes jobs created
   3. `worker_replicas` - the number of parallel worker processes that train the example
   4. `port` - the port used by the TensorFlow worker processes to communicate with each other
   5. `model_dir` - the GCP bucket path that stores the model checkpoints, e.g. `gs://model_dir/`
   6. `image` - the name of the Docker image created in step 1 that needs to be loaded onto the cluster
   7. `log_dir` - the GCP bucket path where the logs are stored, e.g. `gs://log_dir/`
   8. `data_dir` - the GCP bucket path for the CIFAR-10 dataset, e.g. `gs://data_dir/`
   9. `gcp_credential_secret` - the name of the secret created in the Kubernetes cluster that contains the service account credentials
   10. `batch_size` - the global batch size used for training
   11. `num_train_epoch` - the number of training epochs
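With MultiWorkerMirroredStrategy, `batch_size` is the global batch size: each step, the `worker_replicas` workers together consume one global batch, so each replica sees `batch_size / worker_replicas` examples. A plain-Python sketch of that split (the helper is illustrative, not taken from the training script):

```python
def per_replica_batch_size(global_batch_size, num_replicas):
    """Split the global batch size evenly across worker replicas."""
    if global_batch_size % num_replicas != 0:
        raise ValueError("global batch size must be divisible by the replica count")
    return global_batch_size // num_replicas

# e.g. batch_size=512 with worker_replicas=4 gives 128 examples per worker per step
```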
6. Run the job:

   1. Create a namespace to run your training jobs:

      ```sh
      kubectl create namespace <namespace>
      ```

   2. Deploy the training workloads in the cluster:

      ```sh
      python ../../render_template.py myjob.template.jinja | kubectl apply -n <namespace> -f -
      ```

      This creates the Kubernetes jobs on the cluster. Each job has a single service endpoint and a single pod that runs the training image. You can track the running jobs in the cluster by running:

      ```sh
      kubectl get jobs -n <namespace>
      kubectl describe jobs -n <namespace>
      ```

      By default, this also deploys TensorBoard on the cluster:

      ```sh
      kubectl get services -n <namespace> | grep tensorboard
      ```

      Note the external IP corresponding to the service and the previously configured `port` in the YAML. The TensorBoard service should be accessible through the web at `http://<tensorboard-external-ip>:<port>`.

   3. The final model should be available in the GCP bucket corresponding to the `model_dir` configured in the YAML.
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
FROM tensorflow/tensorflow:2.3.1-gpu-jupyter

RUN apt-get update && \
    apt-get install -y python3 python3-pip

RUN pip3 install absl-py && \
    pip3 install portpicker

# Install git and vim
RUN apt-get update && \
    apt-get install -y git && \
    apt-get install -y vim

WORKDIR /app

# Fetch the TensorFlow Model Garden (benchmark branch) and the model
# optimization toolkit so they can be placed on PYTHONPATH.
RUN git clone --single-branch --branch benchmark https://github.com/tensorflow/models.git && \
    mv models tensorflow_models && \
    git clone https://github.com/tensorflow/model-optimization.git && \
    mv model-optimization tensorflow_model_optimization

# Keeps Python from generating .pyc files in the container
ENV PYTHONDONTWRITEBYTECODE=1
# Turns off buffering for easier container logging
ENV PYTHONUNBUFFERED=1

COPY . /app/

ENV PYTHONPATH "${PYTHONPATH}:/:/app/tensorflow_models"

CMD ["python", "resnet_cifar_multiworker_strategy_keras.py"]

distribution_strategy/multi_worker_mirrored_strategy/examples/README.md

Lines changed: 10 additions & 1 deletion
@@ -3,11 +3,13 @@
This directory contains examples of MultiWorkerMirrored training along with the Dockerfiles to build them.

- [Dockerfile](Dockerfile) contains all the dependencies required to build a container image with the training examples using Docker.
- [Dockerfile.gpu](Dockerfile.gpu) contains all the dependencies required to build a GPU-enabled container image with the TensorFlow Model Garden.
- [keras_mnist.py](keras_mnist.py) demonstrates how to train an MNIST classifier using [tf.distribute.MultiWorkerMirroredStrategy and the Keras TensorFlow 2.0 API](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras).
- [custom_training_mnist.py](custom_training_mnist.py) demonstrates how to train a Fashion MNIST classifier using [tf.distribute.MultiWorkerMirroredStrategy and the TensorFlow 2.0 custom training loop APIs](https://www.tensorflow.org/tutorials/distribute/custom_training).
- [keras_resnet_cifar.py](keras_resnet_cifar.py) demonstrates how to train the resnet56 model on the CIFAR-10 dataset using [tf.distribute.MultiWorkerMirroredStrategy and the Keras TensorFlow 2.0 API](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras).

## Best Practices

- Always pin the TensorFlow version with the Docker image tag. This ensures that
@@ -51,3 +53,10 @@ The [custom_training_mnist.py](mnist.py) example demonstrates how to train a fas
[tf.distribute.MultiWorkerMirroredStrategy and the TensorFlow 2.0 custom training loop APIs](https://www.tensorflow.org/tutorials/distribute/custom_training).
The final model is saved to disk by the chief worker process. The disk is assumed to be mounted onto the running container by the cluster manager.
It assumes that the cluster configuration is passed in through the `TF_CONFIG` environment variable when deployed in the cluster.

## Running the keras_resnet_cifar.py example

The [keras_resnet_cifar.py](keras_resnet_cifar.py) example demonstrates how to train a resnet56 model on the CIFAR-10 dataset using
[tf.distribute.MultiWorkerMirroredStrategy and the Keras TensorFlow 2.0 API](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras).
The final model is saved to a GCP storage bucket.
It assumes that the cluster configuration is passed in through the `TF_CONFIG` environment variable when deployed in the cluster.

0 commit comments