88#include < algorithm>
99
1010namespace {
11- void prioritizeTaskByRanku (const nt::TaskFinder& finder)
11+
12+ // Visit tasks in the network in topological order.
13+ template <class F >
14+ void visitTaskTopo (const nt::Network& net, F f)
1215{
13- // c.f. https://en.wikipedia.org/wiki/Heterogeneous_Earliest_Finish_Time
14- // FIXME: Take communication costs into account
15- // FIXME: Tune computation costs by dynamic measurements
16+ using namespace nt ;
1617
18+ std::unordered_map<Task*, size_t > numReadyParents;
19+ std::queue<Task*> que;
20+ net.eachTask ([&](Task* task) {
21+ numReadyParents.emplace (task, 0 );
22+ if (task->areAllInputsReady ())
23+ que.push (task);
24+ });
25+ while (!que.empty ()) {
26+ Task* task = que.front ();
27+ que.pop ();
28+ f (task);
29+ for (Task* child : task->children ()) {
30+ if (child->areAllInputsReady ()) // false parent-child relationship
31+ continue ;
32+ numReadyParents.at (child)++;
33+ assert (child->parents ().size () >= numReadyParents.at (child));
34+ if (child->parents ().size () == numReadyParents.at (child))
35+ que.push (child);
36+ }
37+ }
38+ }
39+
40+ // Visit tasks in the network in reversed topological order
41+ template <class F >
42+ void visitTaskRevTopo (const nt::Network& net, F f)
43+ {
1744 using namespace nt ;
1845
1946 std::unordered_map<Task*, int >
2047 numReadyChildren; // task |-> # of ready children
2148 std::queue<Task*> que; // Initial tasks to be visited
22- finder .eachTask ([&](UID, Task* task) {
49+ net .eachTask ([&](Task* task) {
2350 const std::vector<Task*>& children = task->children ();
2451
2552 // Count the children that have no inputs to wait for
@@ -36,11 +63,60 @@ void prioritizeTaskByRanku(const nt::TaskFinder& finder)
3663 });
3764 assert (!que.empty ());
3865
39- size_t numPrioritizedTasks = 0 ;
4066 while (!que.empty ()) {
4167 Task* task = que.front ();
4268 que.pop ();
69+ f (task);
70+ if (task->areAllInputsReady ()) // The end of the travel
71+ continue ;
72+ // Push parents into the queue if all of their children are ready
73+ for (Task* parent : task->parents ()) {
74+ numReadyChildren.at (parent)++;
75+ assert (parent->children ().size () >= numReadyChildren.at (parent));
76+ if (parent->children ().size () == numReadyChildren.at (parent))
77+ que.push (parent);
78+ }
79+ }
80+ }
81+
82+ void prioritizeTaskByTopo (const nt::Network& net)
83+ {
84+ using namespace nt ;
85+
86+ size_t numPrioritizedTasks = 0 ;
87+ visitTaskTopo (net, [&](Task* task) {
88+ // Calculate and set the priority for the task
89+ int pri = -1 ;
90+ if (!task->areAllInputsReady ())
91+ for (Task* parent : task->parents ())
92+ pri = std::max (pri, parent->priority ());
93+ task->setPriority (pri + 1 );
94+ numPrioritizedTasks++;
95+ });
4396
97+ if (net.size () > numPrioritizedTasks) {
98+ LOG_DBG << " net.size() " << net.size () << " != numPrioritizedTasks "
99+ << numPrioritizedTasks;
100+ net.eachTask ([&](Task* task) {
101+ const Label& l = task->label ();
102+ if (task->priority () == -1 )
103+ LOG_DBG << " \t " << l.uid << " " << l.kind << " " ;
104+ });
105+ ERR_DIE (" Invalid network; some nodes will not be executed." );
106+ }
107+ assert (net.size () == numPrioritizedTasks);
108+ }
109+
110+ void prioritizeTaskByRanku (const nt::Network& net)
111+ {
112+ // c.f. https://en.wikipedia.org/wiki/Heterogeneous_Earliest_Finish_Time
113+ // FIXME: Take communication costs into account
114+ // FIXME: Tune computation costs by dynamic measurements
115+
116+ using namespace nt ;
117+
118+ size_t numPrioritizedTasks = 0 ;
119+ visitTaskRevTopo (net, [&](Task* task) {
44120 // Calculate and set the priority for the task
45121 int pri = 0 ;
46122 for (Task* child : task->children ())
@@ -50,30 +126,21 @@ void prioritizeTaskByRanku(const nt::TaskFinder& finder)
50126 pri = std::max (pri, child->priority ());
51127 task->setPriority (pri + task->getComputationCost ());
52128 numPrioritizedTasks++;
129+ });
53130
54- if (task->areAllInputsReady ()) // The end of the travel
55- continue ;
56-
57- // Push parents into the queue if all of their children are ready
58- for (Task* parent : task->parents ()) {
59- numReadyChildren.at (parent)++;
60- assert (parent->children ().size () >= numReadyChildren.at (parent));
61- if (parent->children ().size () == numReadyChildren.at (parent))
62- que.push (parent);
63- }
64- }
65- if (finder.size () > numPrioritizedTasks) {
66- LOG_DBG << " finder.size() " << finder.size ()
67- << " != numPrioritizedTasks " << numPrioritizedTasks;
68- finder.eachTask ([&](UID, Task* task) {
131+ if (net.size () > numPrioritizedTasks) {
132+ LOG_DBG << " net.size() " << net.size () << " != numPrioritizedTasks "
133+ << numPrioritizedTasks;
134+ net.eachTask ([&](Task* task) {
69135 const Label& l = task->label ();
70136 if (task->priority () == -1 )
71137 LOG_DBG << " \t " << l.uid << " " << l.kind << " " ;
72138 });
73139 ERR_DIE (" Invalid network; some nodes will not be executed." );
74140 }
75- assert (finder .size () == numPrioritizedTasks);
141+ assert (net .size () == numPrioritizedTasks);
76142}
143+
77144} // namespace
78145
79146namespace nt {
@@ -252,6 +319,27 @@ const TaskFinder& Network::finder() const
252319 return finder_;
253320}
254321
322+ bool Network::checkIfValid () const
323+ {
324+ bool valid = true ;
325+ eachTask ([&](Task* task) {
326+ if (!task->checkIfValid ())
327+ valid = false ;
328+ });
329+
330+ // Check if the network is weekly connected
331+ size_t numConnectedTasks = 0 ;
332+ visitTaskTopo (*this , [&](Task*) { numConnectedTasks++; });
333+ if (numConnectedTasks != size ()) {
334+ LOG_S (ERROR) << " The network is not weekly connected i.e., there are "
335+ " some nodes that cannot be visited; numConnectedTasks "
336+ << numConnectedTasks << " != size " << size ();
337+ valid = false ;
338+ }
339+
340+ return valid;
341+ }
342+
255343/* class NetworkBuilder */
256344
257345NetworkBuilder::NetworkBuilder (Allocator& alc)
@@ -474,16 +562,19 @@ void Frontend::buildNetwork(NetworkBuilder& nb)
474562
475563 // Create the network from the builder
476564 network_.emplace (nb.createNetwork ());
477- // FIXME check if network is valid
565+
566+ // Check if network is valid
567+ if (!network_->checkIfValid ())
568+ ERR_DIE (" Network is not valid" );
478569
479570 // Set priority to each task
480571 switch (pr_.sched ) {
481572 case SCHED::TOPO:
482- ERR_DIE ( " Scheduling topo is not supported anymore " ); // FIXME
573+ prioritizeTaskByTopo (network_. value ());
483574 break ;
484575
485576 case SCHED::RANKU:
486- prioritizeTaskByRanku (network_-> finder ());
577+ prioritizeTaskByRanku (network_. value ());
487578 break ;
488579 }
489580}
0 commit comments