forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
engine.h
285 lines (242 loc) · 10.1 KB
/
engine.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
#pragma once
// Engine implements backpropagation from output variables and their gradients
// to "root" variables (variables created by the user with requires_grad=True).
#include <ATen/Tensor.h>
#include <ATen/ThreadLocalState.h>
#include <ATen/core/ivalue.h>
#include <c10/util/CallOnce.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
#include <torch/csrc/autograd/utils/warnings.h>

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
// Forward declaration of ReadyQueue; the complete definition appears further
// down in this header.
namespace torch { namespace autograd { struct ReadyQueue; } } // namespace torch::autograd
namespace torch {
namespace autograd {
// Maximum reentrant backward depth before work is moved to a new thread.
//
// The value is driven by TSAN's deadlock detector, which fails when a single
// thread holds more than 65 locks at once. Every custom C++ autograd Node holds
// a mutex while running, so deep reentrant backwards would otherwise trip TSAN.
// For reference, see https://github.com/google/sanitizers/issues/950
constexpr int MAX_DEPTH = 60;
// Device setter used by the engine; defined out of line.
// NOTE(review): presumably makes `device` current for the calling thread —
// confirm in engine.cpp.
auto set_device(int device) -> void;
// Validates `grads` against the corresponding `edges`; defined out of line.
// `format_error` is a callback that receives a message and returns the
// formatted error text.
// NOTE(review): `grads` is taken by non-const reference, so the
// implementation may adjust gradients in place — confirm in engine.cpp.
auto validate_outputs(
    const edge_list& edges,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error)
    -> void;
// A unit of work placed on a ReadyQueue: run `fn_` on the gradients collected
// in `inputs_`, on behalf of the GraphTask referenced by `base_`.
struct NodeTask {
  NodeTask(
      std::weak_ptr<GraphTask> base,
      std::shared_ptr<Node> fn,
      InputBuffer inputs,
      bool isShutdownTask = false)
      : base_{std::move(base)},
        fn_{std::move(fn)},
        inputs_{std::move(inputs)},
        isShutdownTask_{isShutdownTask} {}

  // Reentrant depth used by ReadyQueue to order tasks; defined out of line.
  // NOTE(review): presumably the depth of the GraphTask in `base_` — confirm.
  int getReentrantDepth() const;

  // Owning graph task, held as a non-owning weak_ptr.
  std::weak_ptr<GraphTask> base_;
  // Node to run. May be null: ReadyQueue treats a task without a function as
  // an "empty" task and orders it right after shutdown tasks.
  std::shared_ptr<Node> fn_;
  // Acts as an implicit "addition" node for every gradient flowing into
  // `fn_`; once all dependencies have finished, its contents are the
  // function's inputs.
  InputBuffer inputs_;
  // A worker that receives a task with this flag set exits immediately. The
  // engine sends one such task to every queue when it is destroyed.
  bool isShutdownTask_;
};
// Scope guard that sets checkpoint_valid for the lifetime of the guard and
// restores the previous value on destruction. Constructor and destructor are
// defined out of line.
class CheckpointValidGuard {
 public:
  explicit CheckpointValidGuard(
      const std::shared_ptr<const GraphTask>& graph_task);
  ~CheckpointValidGuard();

  // The destructor restores saved state, so a copied or moved-from guard
  // would restore it a second time. Guards are scope-bound; forbid both.
  CheckpointValidGuard(const CheckpointValidGuard&) = delete;
  CheckpointValidGuard& operator=(const CheckpointValidGuard&) = delete;
  CheckpointValidGuard(CheckpointValidGuard&&) = delete;
  CheckpointValidGuard& operator=(CheckpointValidGuard&&) = delete;

 private:
  // Value of checkpoint_valid before this guard took effect; restored by the
  // destructor. Default-initialized so it is never read indeterminate.
  bool prev_checkpoint_valid_state{false};
};
// Thread-safe priority queue of NodeTasks that are ready to run.
struct ReadyQueue {
 private:
  // Ordering for heap_. std::priority_queue pops the greatest element under
  // this comparator, so operator()(t1, t2) returns true when t2 must be popped
  // strictly BEFORE t1. Priority, highest first:
  //   1. shutdown tasks,
  //   2. empty tasks (null fn_),
  //   3. all others: larger reentrant depth first, then larger sequence_nr.
  struct CompareNodeTaskTime {
    bool operator()(NodeTask const& t1, NodeTask const& t2) const {
      // Nothing outranks a shutdown task. Checking t1 first keeps the
      // comparator irreflexive when both tasks are shutdown tasks, because
      // std::priority_queue requires a strict weak ordering.
      if (t1.isShutdownTask_) {
        return false;
      }
      if (t2.isShutdownTask_) {
        return true;
      }
      // Empty tasks come next; only shutdown tasks (handled above) outrank
      // them.
      if (!t1.fn_) {
        return false;
      }
      if (!t2.fn_) {
        return true;
      }
      if (t1.getReentrantDepth() == t2.getReentrantDepth()) {
        return t1.fn_->sequence_nr() < t2.fn_->sequence_nr();
      }
      return t1.getReentrantDepth() < t2.getReentrantDepth();
    }
  };

  // Notifies threads waiting on the ReadyQueue that tasks are available in
  // heap_.
  std::condition_variable not_empty_;
  // Protects reads and writes to heap_.
  mutable std::mutex mutex_;
  std::priority_queue<NodeTask, std::vector<NodeTask>, CompareNodeTaskTime>
      heap_;

 public:
  // incrementOutstandingTasks indicates whether or not we should increment
  // 'outstanding_tasks_' for the associated GraphTask. This should almost
  // always be true and is only set false in certain cases (see docs for
  // DistEngine.execute_graph_task_until_ready_queue_empty)
  void push(NodeTask item, bool incrementOutstandingTasks = true);
  // Enqueues a shutdown task, which is ordered ahead of every other task.
  void pushShutdownTask();
  // NOTE(review): presumably blocks on not_empty_ until a task is available —
  // confirm in engine.cpp.
  NodeTask pop();
  bool empty() const;
  size_t size() const;
};
// Engine implements the backward pass: starting from root edges it evaluates
// the autograd graph, dispatching NodeTasks to ReadyQueues (one per device,
// plus a CPU queue) that are serviced by worker threads.
//
// Only a single instance of this struct should be created for the whole
// process lifetime. The worker-thread creation logic and Engine's destructor
// rely on this.
struct TORCH_API Engine {
/// Returns a reference to a static `Engine` instance.
static Engine& get_default_engine();
/// Returns the base `Engine` instance.
/// NOTE(review): presumably this ignores any override installed through
/// set_default_engine_stub — confirm in engine.cpp.
static Engine& get_base_engine();
// Engines are neither copyable nor movable.
Engine(const Engine&) = delete;
Engine(Engine&&) = delete;
virtual ~Engine();
// Given a list of (Node, input number) pairs, computes the value of the graph
// by following next_edge references.
//
// NOTE(review): the parameter roles below are inferred from their names and
// should be confirmed against engine.cpp:
//   roots           - edges the backward pass starts from.
//   inputs          - gradients fed into `roots`.
//   keep_graph      - keep the graph alive after the pass.
//   create_graph    - record the backward computation itself.
//   accumulate_grad - accumulate into leaves instead of returning gradients.
//   outputs         - optional edges whose gradients are returned.
virtual variable_list execute(
const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
bool create_graph,
bool accumulate_grad,
const edge_list& outputs = {});
// Given a pre-populated GraphTask and GraphRoot, computes the backward pass
// for the graph and returns a Future for it.
//
// NB: This API should only be used by internal autograd-specific
// machinery and shouldn't be exposed to users in any way.
virtual c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task(
const std::shared_ptr<GraphTask>& graph_task,
std::shared_ptr<Node> graph_root,
InputBuffer&& input_buffer);
// Factory for AnomalyMetadata. The base engine returns a default-constructed
// instance; virtual so that subclasses can supply their own.
virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
return std::make_unique<AnomalyMetadata>();
}
// Default saved-variable hooks. The base engine provides none (nullptr);
// virtual so that subclasses can supply their own.
virtual std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks() {
return nullptr;
}
// We pass cpu_ready_queue to evaluate_function, so that it knows
// the correct ready queue to push to after a NodeTask is ready
void evaluate_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func,
InputBuffer& inputs,
const std::shared_ptr<ReadyQueue>& cpu_ready_queue);
// NOTE(review): presumably starts the per-device worker threads (cf.
// start_device_threads and start_device_threads_flag_) — confirm.
void initialize_device_threads_pool();
// NOTE(review): presumably invoked when running `fn` for `graph_task` throws
// `e` — confirm. Virtual so subclasses can customize the handling.
virtual void thread_on_exception(
std::shared_ptr<GraphTask> graph_task,
const std::shared_ptr<Node>& fn,
std::exception& e);
// NOTE(review): presumably records `callback` in final_callbacks_ (guarded by
// post_callbacks_lock_) to run after the backward pass — confirm.
void queue_callback(std::function<void()> callback);
// NOTE(review): presumably reports whether checkpointing is valid in the
// current context (cf. CheckpointValidGuard) — confirm.
bool is_checkpoint_valid();
// Should be called after fork to notify that worker threads are gone
void release_workers();
// Must be called by subclass before destructing to avoid a data-race-on-vptr.
void stop();
// Initializes a device thread for the autograd engine.
// NOTE(review): `should_increment` presumably controls whether
// non_reentrant_device_thread_count_ is incremented — confirm.
virtual void thread_init(
int device,
const std::shared_ptr<ReadyQueue>& ready_queue,
bool should_increment = true);
protected:
// Not publicly constructible: obtain an instance via get_default_engine() /
// get_base_engine(), or construct a subclass.
Engine();
// NOTE(review): presumably computes dependency information for the nodes
// reachable from `root` and stores it in `task`; the role of `min_topo_nr`
// should be confirmed in engine.cpp.
void compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo_nr);
// Initialize the thread-local ready queue with a ready queue that was
// created elsewhere (e.g. thread_init, Engine::execute), or create a new
// ready queue if ready_queue is not provided.
void init_local_ready_queue(
std::shared_ptr<ReadyQueue> ready_queue = nullptr);
// Return the ReadyQueue to use for `device` / `device_index`.
// NOTE(review): presumably `cpu_ready_queue` is returned for the CPU and
// device_ready_queues_ is consulted otherwise — confirm.
std::shared_ptr<ReadyQueue> ready_queue(
std::shared_ptr<ReadyQueue> cpu_ready_queue,
at::Device device);
std::shared_ptr<ReadyQueue> ready_queue_by_index(
std::shared_ptr<ReadyQueue> cpu_ready_queue,
int device_index);
// start device threads (CUDA, XLA, etc.) in Engine,
// note that it does NOT start CPU thread.
void start_device_threads();
// NOTE(review): presumably adjust non_reentrant_device_thread_count_, which
// the destructor uses to wait for non-reentrant threads — confirm.
void increment_non_reentrant_thread_count();
void decrement_non_reentrant_thread_count();
// NOTE(review): presumably the body of a worker thread; the meaning of
// `task` (and whether it may be null) should be confirmed in engine.cpp.
virtual void thread_main(const std::shared_ptr<GraphTask>& task);
// Reentrant-backward thread pool (see ThreadPoolShared and Note
// [Reentrant backwards]). NOTE(review): presumably reentrant_thread_init is a
// pool thread's entry point and add_thread_pool_task enqueues `graph_task`
// onto thread_pool_shared_->graphtasks_queue_ — confirm.
void reentrant_thread_init();
void add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task);
// Ensures device_ready_queues_ are initialized only once
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
c10::once_flag start_device_threads_flag_;
// Safe to read device_ready_queues_ without synchronization after
// initialization
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<std::shared_ptr<ReadyQueue>> device_ready_queues_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<std::function<void()>> final_callbacks_;
// To protect reads and writes to final_callbacks_
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::mutex post_callbacks_lock_;
// How many nested reentrant calls are allowed until a new thread is used
// (NOTE(review): presumably initialized from MAX_DEPTH — confirm).
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
int max_recursion_depth_;
// State shared between the engine and its reentrant-backward pool threads.
struct ThreadPoolShared {
// Data structures used by the threads for executing reentrant backwards
// tasks. See Note [Reentrant backwards]
// Number of available threads for processing new GraphTasks.
unsigned int num_workers_{0};
// The threads will wait on work_ to be notified of GraphTasks
std::condition_variable work_;
// To protect reads and writes to graphtask_queue_ and num_workers_
// and for synchronizing creating new threads when needed
std::mutex mutex_;
// Workers will process the GraphTasks added to this queue. A GraphTask is
// allocated inside Engine::execute and lives for the duration of execute
std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;
ThreadPoolShared() = default;
};
// Temporary workaround until shutting down threads is done
// We need shared ownership of all these objects because the threads are
// leaked when Engine shuts down, so there may be threads waiting on work_ for
// the graphtasks_queue_ to be nonempty.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::shared_ptr<ThreadPoolShared> thread_pool_shared_;
private:
// Number of non-reentrant threads
std::atomic<uint32_t> non_reentrant_device_thread_count_;
// Destructor will wait for non-reentrant threads to finish
std::condition_variable non_reentrant_device_thread_condvar_;
// NOTE(review): presumably the mutex paired with
// non_reentrant_device_thread_condvar_ — confirm.
std::mutex non_reentrant_device_thread_mutex_;
// stop() must be called before the destruction path goes down to the base
// class, in order to avoid a data-race-on-vptr. Use this boolean to guard
// whether stop() has already been called, so we can call this in every
// destructor of the class hierarchy.
bool stopped_{false};
};
// allow python_engine to override the default engine when it loads
using EngineStub = Engine& (*)();
TORCH_API void set_default_engine_stub(EngineStub stub);
} // namespace autograd
} // namespace torch