
Commit 4eb34bf
Kiuk Chung authored and facebook-github-bot committed

move 0.2.0 design doc from issue #66 to the repo

Summary: Moving #66 to the `design` directory.
Reviewed By: drdarshan
Differential Revision: D21100874
fbshipit-source-id: 7510479bccaa3bbab25994892fd0da99d0a1a0fe

1 parent 7d8e67c

3 files changed: +142, -0 lines changed

# Introduction

PyTorch Elastic Trainer (PET) provides a framework for conveniently training
models across a compute cluster in a _fault tolerant_ and _elastic_ manner.
PET provides these features in two ways:

1. When a PyTorch worker process throws a certain class of retriable errors, it is caught by PET and the training process is retried.
2. A new worker can leave or join the process pool for an existing training job at any point as long as the number of workers stays within the bounds specified when starting the job. When a membership change happens, all the workers re-rendezvous to establish a new process group and training resumes from the previous well-known good state.

In order to integrate with PET, a PyTorch user needs to make the following
changes to their training logic:

1. They need to enable PET to control their training loop. Essentially, they
   provide an "inner training" loop that is wrapped in a retryable loop by PET.
   All aspects of establishing or re-establishing the process group, as well as
   restoring the user's trainer to a known good state, are handled by the
   retryable PET loop.
2. They need to specify _what_ the state is that needs to be restored in case
   a new worker joins the pool and _how_ the state is applied to a new worker.
   The API for specifying these is described by the `State` object (see the
   sketch after this list).
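
For illustration only, a v0.1-style `State` integration might look roughly like
the sketch below. The class and method names are hypothetical, not PET's actual
interface; they only illustrate the _what_/_how_ split described above.

```py
# Hypothetical sketch of a v0.1-style elastic state; illustrative names only,
# not PET's actual API.
class TrainerState:
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.epoch = 0

    def capture(self):
        # _What_ needs to be restored to resume training on any worker.
        return {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "epoch": self.epoch,
        }

    def apply(self, snapshot):
        # _How_ that state is applied to a worker that (re)joins the pool.
        self.model.load_state_dict(snapshot["model"])
        self.optimizer.load_state_dict(snapshot["optimizer"])
        self.epoch = snapshot["epoch"]
```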

PET v0.1 was released on GitHub, PyPI and Docker Hub in November 2019, and since
then the community has contributed integrations with Amazon Web Services
(via Elastic Kubernetes Service) and Microsoft Azure (via Azure Kubernetes Service).

# Lessons learned from PET v0.1

In porting existing PyTorch-based projects such as
[ClassyVision](https://github.com/facebookresearch/ClassyVision) and
[PyText](https://github.com/facebookresearch/pytext) to use PET, we encountered
a few areas for refinement in the v0.1 design.

**First**, adapting a mature training library such as ClassyVision to use the
elastic training APIs often requires a significant amount of restructuring,
often causing a bifurcation of code paths between the elastic and non-elastic
implementations.

**Second**, it is non-trivial to correctly implement the state restore logic for
each application during in-process recovery. While explicit state such as weight
tensors is easy to save and restore, there is often "hidden" or implicit state
in the application that is hard for the developer to reason about. For example,
after a rendezvous round, a worker process might be expected to restore the state
of C++ objects in either CPU or GPU memory, which is extremely error-prone,
especially after failures or exceptions. To compound this issue, several
applications such as PyText already implement some form of checkpoint/restart, and
this logic often needs to be taken into account when implementing the elastic state.

**Finally**, one of the goals of PET v0.1 was to detect and restart straggler workers.
This was not possible when running the training loop in-process and necessitated
writing an additional watchdog process to monitor the main training process.

For the next iteration of PET, we would like to propose a design that makes it
significantly simpler to port existing training workflows to an elastic
infrastructure and results in applications that can recover more reliably
from workflow failures.

# Overview of the new design

In PET v0.2, _we no longer attempt to recover from errors in the training function_.
Instead, PET attempts to maintain the number of worker processes such that they
stay within the \[_min_, _max_\] bounds required for the job.
The application writer is responsible for loading and restarting from an existing
checkpoint if one is available. Unlike v0.1, PET v0.2 does not mandate how
checkpoints are managed. An application writer is free to use just `torch.save`
and `torch.load` from PyTorch, or a higher-level framework such as
[PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning).
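
As a minimal sketch of this checkpoint-based recovery using only `torch.save`
and `torch.load` (the checkpoint path and the surrounding training objects are
placeholders, not part of PET):

```py
import os

import torch

CHECKPOINT_PATH = "/mnt/shared/checkpoint.pt"  # placeholder: storage visible to all nodes


def load_checkpoint(model, optimizer):
    # Resume from an existing checkpoint if one is available, otherwise start fresh.
    start_epoch = 0
    if os.path.exists(CHECKPOINT_PATH):
        snapshot = torch.load(CHECKPOINT_PATH, map_location="cpu")
        model.load_state_dict(snapshot["model"])
        optimizer.load_state_dict(snapshot["optimizer"])
        start_epoch = snapshot["epoch"] + 1
    return start_epoch


def save_checkpoint(model, optimizer, epoch):
    # Persist everything needed to resume training after a restart.
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        },
        CHECKPOINT_PATH,
    )
```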

PET v0.2 is implemented using a new process named `elastic-agent`.
There is a single `elastic-agent` per job, per node. Each agent process is only
responsible for managing the set of worker processes local to that node and for coordinating
process group membership changes with the elastic agents on the other nodes allocated to
that job. This is illustrated in the diagram below:

![image](torchelastic_diagram.jpg)

Membership changes are handled as follows: when a worker process fails,
the corresponding elastic agent managing it kills all the workers on that node,
establishes rendezvous with the other agents, and restarts the workers with the new
rendezvous information. However, when an agent exits with a non-zero error code,
it is up to a higher-level orchestrator such as Kubernetes to restart the agent
(which in turn will restart all the workers it is responsible for).
The same recovery mechanism holds for node-level failures.
An orchestrator such as Kubernetes will schedule the job such that a minimum number
of elastic agent replicas is running, and each agent will in turn orchestrate the
user's training script.

![image](torchelastic_agent_diagram.jpg)
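
The recovery behavior described above can be summarized by the simplified
sketch below. This is an illustration of the agent's role, not the actual
torchelastic implementation; `rendezvous`, `start_local_workers`, `monitor`
and `kill_all` are hypothetical helpers.

```py
# Simplified, illustrative per-node agent loop; not the real implementation.
def run_agent(spec):
    while True:
        # Agree with the agents on the other nodes on the current set of
        # participants and the new process group configuration.
        rdzv_info = rendezvous(spec)

        # (Re)start the worker processes local to this node with the new
        # rendezvous information.
        workers = start_local_workers(spec, rdzv_info)

        result = monitor(workers)
        if result == "success":
            return 0
        if result == "worker_failure":
            # Kill the remaining local workers and go through another
            # rendezvous round; the other agents do the same.
            kill_all(workers)
            continue
        # Anything else is unrecoverable for this agent: exit non-zero and
        # rely on the orchestrator (e.g. Kubernetes) to restart the agent.
        return 1
```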

To adopt PET v0.2, an application simply needs its entry-point or `main` function
to be compatible with the
[PyTorch distributed launcher](https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py).
We expect distributed training jobs that are started via the distributed launcher
to be seamlessly started via the elastic agent with no or minimal code changes.
The only difference is that in the latter case, the application will be able to
make progress in the presence of certain failures.

# Overview of the API

As mentioned above, with PET v0.2 there is no separate library for a training
application to integrate with. Instead, the user simply launches a training job
via the elastic agent monitor process. For example, if a user starts their job
with the PyTorch distributed launcher as:

```sh
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_ON_NODE \
       TRAINING_SCRIPT.py (... train script args ...)
```
they would instead use:

```sh
python -m torchelastic.distributed.launch --nproc_per_node=NUM_GPUS_ON_NODE \
       --nnodes=1:4 \
       --rdzv_id=JOB_ID \
       --rdzv_backend=etcd \
       --rdzv_endpoint=ETCD_HOST:ETCD_PORT \
       TRAINING_SCRIPT.py (... train script args ...)
```

Notice that it adds a few additional parameters:

1. The min and max number of nodes (`--nnodes=MIN:MAX`). During a rendezvous, if the number of nodes
   drops below the specified threshold, the job is aborted.
2. A rendezvous type and its configuration (the `--rdzv_*` arguments).

Inside the training script, the only potential change the user needs to make is
to ensure that they use environment variables to initialize the process group,
i.e., create the process group as follows:

```py
import torch.distributed as dist

dist.init_process_group(init_method="env://", backend="gloo")
# or
dist.init_process_group(init_method="env://", backend="nccl")
```

All the parameters for initializing the group (the world size, the numerical
rank, the master address and port) are passed in as environment variables
by the parent elastic agent.
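
For example, a training script launched this way could rely entirely on those
environment variables, along the lines of the sketch below. The use of a
`LOCAL_RANK` environment variable to pick the GPU is an assumption of this
sketch; depending on the launcher, the local rank may instead be passed as a
`--local_rank` command-line argument.

```py
import os

import torch
import torch.distributed as dist

# RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are supplied by the parent
# elastic agent and consumed by init_process_group via init_method="env://".
dist.init_process_group(init_method="env://", backend="nccl")

# Assumption: the local rank is exposed as the LOCAL_RANK environment variable.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

# Wrap the model with DistributedDataParallel as usual.
model = torch.nn.Linear(10, 10).cuda(local_rank)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```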

The new PET design is intentionally "bare-bones": it trades off the granularity
with which an application can recover for simplicity and robustness.
In the future, we hope to provide more APIs for convenient checkpointing that a
developer can optionally use for more efficient restart semantics.

# Implementation details and next steps

An implementation of the above ideas is available in [PR #65](https://github.com/pytorch/elastic/pull/65).
We encourage the community to evaluate the new functionality and
give us feedback on the trade-offs we have made in the design, either in the PR
or in the original issue (#66). We look forward to hearing from you!