Skip to content

Commit

Permalink
[CONFORMANCE] Extend Graph Cache by Fused names based extractor (open…
Browse files Browse the repository at this point in the history
…vinotoolkit#18601)

* Add fused names based subgraph extractor

* fix test build
  • Loading branch information
iefode authored Jul 24, 2023
1 parent bc734df commit 1d3dec9
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
#include "cache/meta/input_info.hpp"
#include "matchers/subgraph/manager.hpp"
#include "matchers/subgraph/subgraph.hpp"
#include "matchers/subgraph/fused_names.hpp"

namespace ov {
namespace tools {
namespace subgraph_dumper {

class GraphCache : public ICache {
public:
void update_cache(const std::shared_ptr<ov::Model>& model,
const std::string& model_meta_data,
void update_cache(const std::shared_ptr<ov::Model>& model, const std::string& model_meta_data,
bool extract_body = true) override;
void serialize_cache() override;

Expand All @@ -43,7 +43,9 @@ class GraphCache : public ICache {
static std::shared_ptr<GraphCache> m_cache_instance;

GraphCache() {
ExtractorsManager::ExtractorsMap matchers = {};
ExtractorsManager::ExtractorsMap matchers = {
{ "fused_names", FusedNamesExtractor::Ptr(new FusedNamesExtractor) },
};
m_manager.set_extractors(matchers);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <utility>

#include "matchers/subgraph/subgraph.hpp"

namespace ov {
namespace tools {
namespace subgraph_dumper {

class FusedNamesExtractor : public SubgraphExtractor {
public:
std::list<ExtractedPattern> extract(const std::shared_ptr<ov::Model> &model,
bool is_extract_body = true) override;

protected:
std::unordered_set<std::string> extract_compiled_model_names(const std::shared_ptr<ov::Model>& model);
};

} // namespace subgraph_dumper
} // namespace tools
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/op/lstm_cell.hpp"
#include "openvino/op/tensor_iterator.hpp"
#include "openvino/op/if.hpp"

#include "functional_test_utils/ov_plugin_cache.hpp"

#include "matchers/subgraph/fused_names.hpp"
#include "utils/model.hpp"

using namespace ov::tools::subgraph_dumper;

std::unordered_set<std::string>
FusedNamesExtractor::extract_compiled_model_names(const std::shared_ptr<ov::Model>& model) {
auto core = ov::test::utils::PluginCache::get().core();
auto compiled_model = core->compile_model(model);
std::unordered_set<std::string> compiled_op_name;
for (const auto& compiled_op : compiled_model.get_runtime_model()->get_ordered_ops()) {
const auto& rt_info = compiled_op->get_rt_info();
if (rt_info.count("originalLayersNames")) {
compiled_op_name.insert(rt_info.find("originalLayersNames")->second.as<std::string>());
}
}
return compiled_op_name;
}

std::list<ExtractedPattern>
FusedNamesExtractor::extract(const std::shared_ptr<ov::Model> &model,
bool is_extract_body) {
auto compiled_op_name = extract_compiled_model_names(model);
std::list<ExtractedPattern> matched_patterns;
std::unordered_set<std::string> checked_ops;
std::set<std::shared_ptr<ov::Node>> nodes;
std::shared_ptr<ov::Node> start_node = nullptr;
for (const auto& op : model->get_ordered_ops()) {
auto op_name = op->get_friendly_name();
if (is_node_to_skip(op) || checked_ops.count(op_name)) {
continue;
}
if (start_node == nullptr) {
start_node = op;
}
nodes.insert(op);
if (is_extract_body) {
if (std::dynamic_pointer_cast<ov::op::v0::TensorIterator>(op)) {
auto ti = ov::as_type_ptr<ov::op::v0::TensorIterator>(op);
auto ti_body = ti->get_function();
auto tmp_res = extract(ti_body);
matched_patterns.insert(matched_patterns.end(), tmp_res.begin(), tmp_res.end());
} else if (std::dynamic_pointer_cast<ov::op::v5::Loop>(op)) {
auto loop = ov::as_type_ptr<ov::op::v5::Loop>(op);
auto loop_body = loop->get_function();
auto tmp_res = extract(loop_body);
matched_patterns.insert(matched_patterns.end(), tmp_res.begin(), tmp_res.end());
} else if (std::dynamic_pointer_cast<ov::op::v8::If>(op)) {
auto if_op = ov::as_type_ptr<ov::op::v8::If>(op);
std::vector<std::shared_ptr<ov::Model>> bodies;
for (size_t i = 0; i < if_op->get_internal_subgraphs_size(); i++) {
auto if_body = if_op->get_function(i);
auto tmp_res = extract(if_body);
matched_patterns.insert(matched_patterns.end(), tmp_res.begin(), tmp_res.end());
}
}
}
if (!compiled_op_name.count(op_name)) {
try {
matched_patterns.push_back(generate_model(nodes, start_node, checked_ops));
} catch(std::exception& e) {
std::cout << e.what() << std::endl;
}
start_node = nullptr;
nodes.clear();
}
}
try {
matched_patterns.push_back(generate_model(nodes, start_node, checked_ops));
} catch(std::exception& e) {
std::cout << e.what() << std::endl;
}
return matched_patterns;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "gtest/gtest.h"

#include "matchers/subgraph/fused_names.hpp"
#include "utils/model.hpp"

#include "test_models/model_0.hpp"
#include "test_models/model_1.hpp"
#include "test_models/model_2.hpp"

namespace {

using namespace ov::tools::subgraph_dumper;


// ======================= ExtractorsManagerTest Unit tests =======================
class FusedNamesExtractorTest : public FusedNamesExtractor,
public ::testing::Test {
protected:
bool is_match(const std::shared_ptr<ov::Model>& model) {
auto compiled_names = extract_compiled_model_names(model);
std::set<std::string> diff;
for (const auto& op : model->get_ordered_ops()) {
auto op_name = op->get_friendly_name();
if (!compiled_names.count(op_name)) {
diff.insert(op_name);
}
}
auto models = this->extract(model);
return diff.size() == 0 ? true : models.size() + 2 == diff.size();
}
};

TEST_F(FusedNamesExtractorTest, extract_0) {
auto test_model = Model_0();
auto model = test_model.get();
ASSERT_TRUE(is_match(model));
}

TEST_F(FusedNamesExtractorTest, extract_1) {
auto test_model = Model_1();
auto model = test_model.get();
ASSERT_TRUE(is_match(model));
}

TEST_F(FusedNamesExtractorTest, extract_2) {
auto test_model = Model_2();
auto model = test_model.get();
ASSERT_TRUE(is_match(model));
}

} // namespace

0 comments on commit 1d3dec9

Please sign in to comment.