@@ -166,17 +166,33 @@ std::vector<std::unique_ptr<ComputeCapability>> GetCapability::Execute() {
166
166
auto connected_clusters = GetConnectedClusters (graph_viewer_, ng_clusters);
167
167
168
168
int no_of_clusters = 0 ;
169
+ std::vector<NodeIndex> prev_cluster;
170
+ bool try_next_cluster = false ;
169
171
170
172
for (auto this_cluster : connected_clusters) {
173
+ bool omit_subgraph = false ;
174
+ if (try_next_cluster) {
175
+ // no need to check previous cluster
176
+ for (auto idx : prev_cluster) {
177
+ if ((std::find (this_cluster.begin (), this_cluster.end (), idx)) == this_cluster.end ()) {
178
+ this_cluster.emplace_back (idx);
179
+ }
180
+ }
181
+ try_next_cluster = false ;
182
+ }
183
+
171
184
// If subgraph has less then three, graph is considered trivial unless its an epctx cluster
172
- if (this_cluster.size () < 3 ) {
185
+ if (!try_next_cluster && this_cluster.size () < 3 ) {
173
186
bool is_epctx_node = false ;
174
187
for (auto node_idx : this_cluster) {
175
188
if (graph_viewer_.GetNode (node_idx)->OpType () == " EPContext" )
176
189
is_epctx_node = true ;
177
190
}
178
- if (!is_epctx_node)
179
- continue ;
191
+ if (!is_epctx_node) {
192
+ omit_subgraph = true ;
193
+ prev_cluster = this_cluster;
194
+ try_next_cluster = true ;
195
+ }
180
196
}
181
197
182
198
std::vector<std::string> cluster_graph_inputs, cluster_inputs, cluster_outputs;
@@ -188,7 +204,7 @@ std::vector<std::unique_ptr<ComputeCapability>> GetCapability::Execute() {
188
204
cluster_inputs,
189
205
cluster_outputs);
190
206
191
- bool omit_subgraph = false ;
207
+
192
208
// Omitting zero dim subgraphs
193
209
for (auto index : this_cluster) {
194
210
const Node* node = graph_viewer_.GetNode (index);
0 commit comments