Skip to content

Commit 1bed0b6

Browse files
wip
1 parent 11a6a1c commit 1bed0b6

File tree

2 files changed

+147
-25
lines changed

2 files changed

+147
-25
lines changed

src/iyokan_nt.cpp

Lines changed: 116 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,45 @@
88
#include <algorithm>
99

1010
namespace {
11-
void prioritizeTaskByRanku(const nt::TaskFinder& finder)
11+
12+
// Visit tasks in the network in topological order.
13+
template <class F>
14+
void visitTaskTopo(const nt::Network& net, F f)
1215
{
13-
// c.f. https://en.wikipedia.org/wiki/Heterogeneous_Earliest_Finish_Time
14-
// FIXME: Take communication costs into account
15-
// FIXME: Tune computation costs by dynamic measurements
16+
using namespace nt;
1617

18+
std::unordered_map<Task*, size_t> numReadyParents;
19+
std::queue<Task*> que;
20+
net.eachTask([&](Task* task) {
21+
numReadyParents.emplace(task, 0);
22+
if (task->areAllInputsReady())
23+
que.push(task);
24+
});
25+
while (!que.empty()) {
26+
Task* task = que.front();
27+
que.pop();
28+
f(task);
29+
for (Task* child : task->children()) {
30+
if (child->areAllInputsReady()) // false parent-child relationship
31+
continue;
32+
numReadyParents.at(child)++;
33+
assert(child->parents().size() >= numReadyParents.at(child));
34+
if (child->parents().size() == numReadyParents.at(child))
35+
que.push(child);
36+
}
37+
}
38+
}
39+
40+
// Visit tasks in the network in reversed topological order
41+
template <class F>
42+
void visitTaskRevTopo(const nt::Network& net, F f)
43+
{
1744
using namespace nt;
1845

1946
std::unordered_map<Task*, int>
2047
numReadyChildren; // task |-> # of ready children
2148
std::queue<Task*> que; // Initial tasks to be visited
22-
finder.eachTask([&](UID, Task* task) {
49+
net.eachTask([&](Task* task) {
2350
const std::vector<Task*>& children = task->children();
2451

2552
// Count the children that have no inputs to wait for
@@ -36,11 +63,60 @@ void prioritizeTaskByRanku(const nt::TaskFinder& finder)
3663
});
3764
assert(!que.empty());
3865

39-
size_t numPrioritizedTasks = 0;
4066
while (!que.empty()) {
4167
Task* task = que.front();
4268
que.pop();
69+
f(task);
70+
if (task->areAllInputsReady()) // The end of the travel
71+
continue;
72+
// Push parents into the queue if all of their children are ready
73+
for (Task* parent : task->parents()) {
74+
numReadyChildren.at(parent)++;
75+
assert(parent->children().size() >= numReadyChildren.at(parent));
76+
if (parent->children().size() == numReadyChildren.at(parent))
77+
que.push(parent);
78+
}
79+
}
80+
}
81+
82+
void prioritizeTaskByTopo(const nt::Network& net)
83+
{
84+
using namespace nt;
85+
86+
size_t numPrioritizedTasks = 0;
87+
visitTaskTopo(net, [&](Task* task) {
88+
// Calculate and set the priority for the task
89+
int pri = -1;
90+
if (!task->areAllInputsReady())
91+
for (Task* parent : task->parents())
92+
pri = std::max(pri, parent->priority());
93+
task->setPriority(pri + 1);
94+
numPrioritizedTasks++;
95+
});
4396

97+
if (net.size() > numPrioritizedTasks) {
98+
LOG_DBG << "net.size() " << net.size() << " != numPrioritizedTasks "
99+
<< numPrioritizedTasks;
100+
net.eachTask([&](Task* task) {
101+
const Label& l = task->label();
102+
if (task->priority() == -1)
103+
LOG_DBG << "\t" << l.uid << " " << l.kind << " ";
104+
});
105+
ERR_DIE("Invalid network; some nodes will not be executed.");
106+
}
107+
assert(net.size() == numPrioritizedTasks);
108+
}
109+
110+
void prioritizeTaskByRanku(const nt::Network& net)
111+
{
112+
// c.f. https://en.wikipedia.org/wiki/Heterogeneous_Earliest_Finish_Time
113+
// FIXME: Take communication costs into account
114+
// FIXME: Tune computation costs by dynamic measurements
115+
116+
using namespace nt;
117+
118+
size_t numPrioritizedTasks = 0;
119+
visitTaskRevTopo(net, [&](Task* task) {
44120
// Calculate and set the priority for the task
45121
int pri = 0;
46122
for (Task* child : task->children())
@@ -50,30 +126,21 @@ void prioritizeTaskByRanku(const nt::TaskFinder& finder)
50126
pri = std::max(pri, child->priority());
51127
task->setPriority(pri + task->getComputationCost());
52128
numPrioritizedTasks++;
129+
});
53130

54-
if (task->areAllInputsReady()) // The end of the travel
55-
continue;
56-
57-
// Push parents into the queue if all of their children are ready
58-
for (Task* parent : task->parents()) {
59-
numReadyChildren.at(parent)++;
60-
assert(parent->children().size() >= numReadyChildren.at(parent));
61-
if (parent->children().size() == numReadyChildren.at(parent))
62-
que.push(parent);
63-
}
64-
}
65-
if (finder.size() > numPrioritizedTasks) {
66-
LOG_DBG << "finder.size() " << finder.size()
67-
<< " != numPrioritizedTasks " << numPrioritizedTasks;
68-
finder.eachTask([&](UID, Task* task) {
131+
if (net.size() > numPrioritizedTasks) {
132+
LOG_DBG << "net.size() " << net.size() << " != numPrioritizedTasks "
133+
<< numPrioritizedTasks;
134+
net.eachTask([&](Task* task) {
69135
const Label& l = task->label();
70136
if (task->priority() == -1)
71137
LOG_DBG << "\t" << l.uid << " " << l.kind << " ";
72138
});
73139
ERR_DIE("Invalid network; some nodes will not be executed.");
74140
}
75-
assert(finder.size() == numPrioritizedTasks);
141+
assert(net.size() == numPrioritizedTasks);
76142
}
143+
77144
} // namespace
78145

79146
namespace nt {
@@ -252,6 +319,27 @@ const TaskFinder& Network::finder() const
252319
return finder_;
253320
}
254321

322+
bool Network::checkIfValid() const
323+
{
324+
bool valid = true;
325+
eachTask([&](Task* task) {
326+
if (!task->checkIfValid())
327+
valid = false;
328+
});
329+
330+
// Check if the network is weekly connected
331+
size_t numConnectedTasks = 0;
332+
visitTaskTopo(*this, [&](Task*) { numConnectedTasks++; });
333+
if (numConnectedTasks != size()) {
334+
LOG_S(ERROR) << "The network is not weekly connected i.e., there are "
335+
"some nodes that cannot be visited; numConnectedTasks "
336+
<< numConnectedTasks << " != size " << size();
337+
valid = false;
338+
}
339+
340+
return valid;
341+
}
342+
255343
/* class NetworkBuilder */
256344

257345
NetworkBuilder::NetworkBuilder(Allocator& alc)
@@ -474,16 +562,19 @@ void Frontend::buildNetwork(NetworkBuilder& nb)
474562

475563
// Create the network from the builder
476564
network_.emplace(nb.createNetwork());
477-
// FIXME check if network is valid
565+
566+
// Check if network is valid
567+
if (!network_->checkIfValid())
568+
ERR_DIE("Network is not valid");
478569

479570
// Set priority to each task
480571
switch (pr_.sched) {
481572
case SCHED::TOPO:
482-
ERR_DIE("Scheduling topo is not supported anymore"); // FIXME
573+
prioritizeTaskByTopo(network_.value());
483574
break;
484575

485576
case SCHED::RANKU:
486-
prioritizeTaskByRanku(network_->finder());
577+
prioritizeTaskByRanku(network_.value());
487578
break;
488579
}
489580
}

src/iyokan_nt.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ class Task {
116116
// Get computation cost of this task. Used for scheduling of tasks.
117117
virtual int getComputationCost() const;
118118

119+
// Check if the task is valid. Returns true iff it is valid.
120+
virtual bool checkIfValid() const = 0;
121+
119122
virtual void notifyOneInputReady() = 0;
120123
virtual bool areAllInputsReady() const = 0;
121124
virtual bool hasFinished() const = 0;
@@ -207,6 +210,30 @@ class TaskCommon : public Task {
207210
{
208211
}
209212

213+
virtual bool checkIfValid() const override
214+
{
215+
assert(output_);
216+
217+
bool valid = true;
218+
if (getInputSize() < numMinExpectedInputs_) {
219+
LOG_S(ERROR) << "Input size < min. expected size: "
220+
<< getInputSize() << " < " << numMinExpectedInputs_;
221+
valid = false;
222+
}
223+
if (getInputSize() > numMaxExpectedInputs_) {
224+
LOG_S(ERROR) << "Input size > max. expected size: "
225+
<< getInputSize() << " > " << numMaxExpectedInputs_;
226+
valid = false;
227+
}
228+
if (getInputSize() != parents().size()) {
229+
LOG_S(ERROR) << "Input size != parents size: " << getInputSize()
230+
<< " != " << parents().size();
231+
valid = false;
232+
}
233+
234+
return valid;
235+
}
236+
210237
virtual void notifyOneInputReady() override
211238
{
212239
numReadyInputs_++;
@@ -311,6 +338,10 @@ class Network {
311338
size_t size() const;
312339
const TaskFinder& finder() const;
313340

341+
// Check if the network is valid. Print error messages if necessary. Returns
342+
// true iff it is valid.
343+
bool checkIfValid() const;
344+
314345
template <class F>
315346
void eachTask(F f) const
316347
{

0 commit comments

Comments
 (0)