From 3c8be082531e41b26eea4a8e310160c14da6b5e4 Mon Sep 17 00:00:00 2001 From: MistEO Date: Thu, 12 Oct 2023 00:11:26 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=BF=9B=E4=B8=80=E6=AD=A5=E5=AE=8C?= =?UTF-8?q?=E6=88=90FeatureMatcher?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../MaaFramework/Resource/PipelineResMgr.cpp | 43 +++++++++ source/MaaFramework/Resource/PipelineResMgr.h | 2 + source/MaaFramework/Resource/PipelineTypes.h | 7 +- source/MaaFramework/Task/Recognizer.cpp | 37 ++++++++ source/MaaFramework/Task/Recognizer.h | 2 + source/MaaFramework/Vision/FeatureMatcher.cpp | 91 +++++++++++++------ source/MaaFramework/Vision/FeatureMatcher.h | 40 ++++++-- source/MaaFramework/Vision/VisionTypes.h | 4 +- 8 files changed, 185 insertions(+), 41 deletions(-) diff --git a/source/MaaFramework/Resource/PipelineResMgr.cpp b/source/MaaFramework/Resource/PipelineResMgr.cpp index 63c7cce36..b9b6d81ad 100644 --- a/source/MaaFramework/Resource/PipelineResMgr.cpp +++ b/source/MaaFramework/Resource/PipelineResMgr.cpp @@ -339,6 +339,7 @@ bool PipelineResMgr::parse_recognition(const json::value& input, Recognition::Ty { kDefaultRecognitionFlag, default_type }, { "DirectHit", Type::DirectHit }, { "TemplateMatch", Type::TemplateMatch }, + { "FeatureMatch", Type::FeatureMatch }, { "OCR", Type::OCR }, { "NeuralNetworkClassify", Type::NeuralNetworkClassify }, { "NeuralNetworkDetect", Type::NeuralNetworkDetect }, @@ -366,6 +367,12 @@ bool PipelineResMgr::parse_recognition(const json::value& input, Recognition::Ty same_type ? std::get(default_param) : TemplateMatcherParam {}); + case Type::FeatureMatch: + out_param = FeatureMatcherParam {}; + return parse_feature_matcher_param(input, std::get(out_param), + same_type ? 
std::get(default_param) + : FeatureMatcherParam {}); + case Type::NeuralNetworkClassify: out_param = NeuralNetworkClassifierParam {}; return parse_nn_classifier_param(input, std::get(out_param), @@ -456,6 +463,42 @@ bool PipelineResMgr::parse_template_matcher_param(const json::value& input, MAA_ return true; } +bool PipelineResMgr::parse_feature_matcher_param(const json::value& input, MAA_VISION_NS::FeatureMatcherParam& output, + const MAA_VISION_NS::FeatureMatcherParam& default_value) +{ + if (!parse_roi(input, output.roi, default_value.roi)) { + LogError << "failed to parse_roi" << VAR(input); + return false; + } + + if (!get_and_check_value(input, "template", output.template_path, default_value.template_path)) { + LogError << "failed to get_and_check_value template_path" << VAR(input); + return false; + } + + if (!get_and_check_value(input, "green_mask", output.green_mask, default_value.green_mask)) { + LogError << "failed to get_and_check_value green_mask" << VAR(input); + return false; + } + + if (!get_and_check_value(input, "hessian", output.hessian, default_value.hessian)) { + LogError << "failed to get_and_check_value hessian" << VAR(input); + return false; + } + + if (!get_and_check_value(input, "distance_ratio", output.distance_ratio, default_value.distance_ratio)) { + LogError << "failed to get_and_check_value distance_ratio" << VAR(input); + return false; + } + + if (!get_and_check_value(input, "count", output.count, default_value.count)) { + LogError << "failed to get_and_check_value count" << VAR(input); + return false; + } + + return true; +} + bool PipelineResMgr::parse_ocrer_param(const json::value& input, MAA_VISION_NS::OCRerParam& output, const MAA_VISION_NS::OCRerParam& default_value) { diff --git a/source/MaaFramework/Resource/PipelineResMgr.h b/source/MaaFramework/Resource/PipelineResMgr.h index 44df687fb..2f5225e1b 100644 --- a/source/MaaFramework/Resource/PipelineResMgr.h +++ b/source/MaaFramework/Resource/PipelineResMgr.h @@ -39,6 +39,8 @@ class 
PipelineResMgr : public NonCopyable // const MAA_VISION_NS::DirectHitParam& default_value); static bool parse_template_matcher_param(const json::value& input, MAA_VISION_NS::TemplateMatcherParam& output, const MAA_VISION_NS::TemplateMatcherParam& default_value); + static bool parse_feature_matcher_param(const json::value& input, MAA_VISION_NS::FeatureMatcherParam& output, + const MAA_VISION_NS::FeatureMatcherParam& default_value); static bool parse_ocrer_param(const json::value& input, MAA_VISION_NS::OCRerParam& output, const MAA_VISION_NS::OCRerParam& default_value); static bool parse_custom_recognizer_param(const json::value& input, MAA_VISION_NS::CustomRecognizerParam& output, diff --git a/source/MaaFramework/Resource/PipelineTypes.h b/source/MaaFramework/Resource/PipelineTypes.h index 1dfda25f0..275d0ddff 100644 --- a/source/MaaFramework/Resource/PipelineTypes.h +++ b/source/MaaFramework/Resource/PipelineTypes.h @@ -23,6 +23,7 @@ enum class Type Invalid = 0, DirectHit, TemplateMatch, + FeatureMatch, OCR, NeuralNetworkClassify, NeuralNetworkDetect, @@ -31,9 +32,9 @@ enum class Type }; using Param = std::variant; + MAA_VISION_NS::FeatureMatcherParam, MAA_VISION_NS::OCRerParam, + MAA_VISION_NS::NeuralNetworkClassifierParam, MAA_VISION_NS::NeuralNetworkDetectorParam, + MAA_VISION_NS::ColorMatcherParam, MAA_VISION_NS::CustomRecognizerParam>; } // namespace Recognition namespace Action diff --git a/source/MaaFramework/Task/Recognizer.cpp b/source/MaaFramework/Task/Recognizer.cpp index 96a0fb556..2f8d44e8e 100644 --- a/source/MaaFramework/Task/Recognizer.cpp +++ b/source/MaaFramework/Task/Recognizer.cpp @@ -5,6 +5,7 @@ #include "Utils/Logger.h" #include "Vision/ColorMatcher.h" #include "Vision/CustomRecognizer.h" +#include "Vision/FeatureMatcher.h" #include "Vision/NeuralNetworkClassifier.h" #include "Vision/NeuralNetworkDetector.h" #include "Vision/OCRer.h" @@ -35,6 +36,10 @@ std::optional Recognizer::recognize(const cv::Mat& image, co result = template_match(image, 
std::get(task_data.rec_param), task_data.name); break; + case Type::FeatureMatch: + result = feature_match(image, std::get(task_data.rec_param), task_data.name); + break; + case Type::ColorMatch: result = color_match(image, std::get(task_data.rec_param), task_data.name); break; @@ -117,6 +122,38 @@ std::optional Recognizer::template_match(const cv::Mat& imag return Result { .box = box, .detail = detail.to_string() }; } +std::optional Recognizer::feature_match(const cv::Mat& image, + const MAA_VISION_NS::FeatureMatcherParam& param, + const std::string& name) +{ + using namespace MAA_VISION_NS; + + if (!resource()) { + LogError << "Resource not binded"; + return std::nullopt; + } + + FeatureMatcher matcher; + matcher.set_image(image); + matcher.set_name(name); + matcher.set_param(param); + + std::shared_ptr templ = resource()->template_res().image(param.template_path); + matcher.set_template(std::move(templ)); + + auto ret = matcher.analyze(); + if (ret.empty()) { + return std::nullopt; + } + + const cv::Rect& box = ret.front().box; + json::array detail; + for (const auto& res : ret) { + detail.emplace_back(res.to_json()); + } + return Result { .box = box, .detail = detail.to_string() }; +} + std::optional Recognizer::color_match(const cv::Mat& image, const MAA_VISION_NS::ColorMatcherParam& param, const std::string& name) diff --git a/source/MaaFramework/Task/Recognizer.h b/source/MaaFramework/Task/Recognizer.h index 1294af823..63b4b02d6 100644 --- a/source/MaaFramework/Task/Recognizer.h +++ b/source/MaaFramework/Task/Recognizer.h @@ -35,6 +35,8 @@ class Recognizer std::optional direct_hit(); std::optional template_match(const cv::Mat& image, const MAA_VISION_NS::TemplateMatcherParam& param, const std::string& name); + std::optional feature_match(const cv::Mat& image, const MAA_VISION_NS::FeatureMatcherParam& param, + const std::string& name); std::optional color_match(const cv::Mat& image, const MAA_VISION_NS::ColorMatcherParam& param, const std::string& name); 
std::optional ocr(const cv::Mat& image, const MAA_VISION_NS::OCRerParam& param, const std::string& name); diff --git a/source/MaaFramework/Vision/FeatureMatcher.cpp b/source/MaaFramework/Vision/FeatureMatcher.cpp index 80d071a6d..5ba857ac5 100644 --- a/source/MaaFramework/Vision/FeatureMatcher.cpp +++ b/source/MaaFramework/Vision/FeatureMatcher.cpp @@ -40,37 +40,23 @@ FeatureMatcher::ResultsVec FeatureMatcher::foreach_rois(const cv::Mat& templ) co return {}; } - auto matcher = create_matcher(templ); + auto [keypoints, descriptors] = detect(templ, param_.green_mask); + auto matcher = create_matcher(keypoints, descriptors); if (param_.roi.empty()) { - return match(cv::Rect(0, 0, image_.cols, image_.rows), matcher); + return match(matcher, keypoints, cv::Rect(0, 0, image_.cols, image_.rows)); } ResultsVec results; for (const cv::Rect& roi : param_.roi) { - ResultsVec res = match(roi, matcher); + ResultsVec res = match(matcher, keypoints, roi); results.insert(results.end(), std::make_move_iterator(res.begin()), std::make_move_iterator(res.end())); } return results; } -cv::FlannBasedMatcher FeatureMatcher::create_matcher(const cv::Mat& templ) const -{ - std::vector keypoints_1; - cv::Mat descriptors_1; - detect(templ, param_.green_mask, keypoints_1, descriptors_1); - - std::vector train_desc(1, descriptors_1); - cv::FlannBasedMatcher matcher; - matcher.add(train_desc); - matcher.train(); - - return matcher; -} - -void FeatureMatcher::detect(const cv::Mat& image, bool green_mask, std::vector& keypoints, - cv::Mat& descriptors) const +std::pair, cv::Mat> FeatureMatcher::detect(const cv::Mat& image, bool green_mask) const { auto detector = cv::xfeatures2d::SURF::create(param_.hessian); @@ -80,33 +66,80 @@ void FeatureMatcher::detect(const cv::Mat& image, bool green_mask, std::vector keypoints; + cv::Mat descriptors; detector->detectAndCompute(image, mask, keypoints, descriptors); + + return std::make_pair(std::move(keypoints), std::move(descriptors)); } 
-FeatureMatcher::ResultsVec FeatureMatcher::match(const cv::Rect& roi, cv::FlannBasedMatcher& matcher) const +cv::FlannBasedMatcher FeatureMatcher::create_matcher(const std::vector& keypoints, + const cv::Mat& descriptors) const { - cv::Mat image = image_with_roi(roi); - std::vector keypoints_2; - cv::Mat descriptors_2; - detect(image, false, keypoints_2, descriptors_2); + std::ignore = keypoints; + + std::vector train_desc(1, descriptors); + cv::FlannBasedMatcher matcher; + matcher.add(train_desc); + matcher.train(); + + return matcher; +} + +FeatureMatcher::ResultsVec FeatureMatcher::match(cv::FlannBasedMatcher& matcher, + const std::vector& keypoints_1, + const cv::Rect& roi_2) const +{ + auto image_2 = image_with_roi(roi_2); + auto [keypoints_2, descriptors_2] = detect(image_2, false); std::vector> match_points; matcher.knnMatch(descriptors_2, match_points, 2); - ResultsVec results; + std::vector good_matches; + std::vector good_points; for (const auto& point : match_points) { if (point.size() != 2) { continue; } - if (point[0].distance < param_.distance_ratio * point[1].distance) { - // TODO + double threshold = param_.distance_ratio * point[1].distance; + if (point[0].distance > threshold) { + continue; } + good_matches.emplace_back(point[0]); + + cv::Point pt = keypoints_2[point[0].queryIdx].pt; + good_points.emplace_back(pt); } + + draw_result(*template_, keypoints_1, roi_2, keypoints_2, good_matches); + + return {}; } -void FeatureMatcher::draw_result(const cv::Rect& roi, const ResultsVec& results) const {} +void FeatureMatcher::draw_result(const cv::Mat& templ, const std::vector& keypoints_1, + const cv::Rect& roi, const std::vector& keypoints_2, + const std::vector& good_matches) const +{ + if (!debug_draw_) { + return; + } + + cv::Mat image_draw = draw_roi(roi); + // const auto color = cv::Scalar(0, 0, 255); + + cv::drawMatches(image_draw, keypoints_2, templ, keypoints_1, good_matches, image_draw); + + if (save_draw_) { + save_image(image_draw); + } +}
-void FeatureMatcher::filter(ResultsVec& results, int count) const {} +void FeatureMatcher::filter(ResultsVec& results, int count) const +{ + std::ignore = results; + std::ignore = count; +} MAA_VISION_NS_END diff --git a/source/MaaFramework/Vision/FeatureMatcher.h b/source/MaaFramework/Vision/FeatureMatcher.h index 0f261a6fb..bdb4fbc86 100644 --- a/source/MaaFramework/Vision/FeatureMatcher.h +++ b/source/MaaFramework/Vision/FeatureMatcher.h @@ -3,7 +3,11 @@ #include #include +#include "Conf/Conf.h" + +MAA_SUPPRESS_CV_WARNINGS_BEGIN #include +MAA_SUPPRESS_CV_WARNINGS_END #include "VisionBase.h" #include "VisionTypes.h" @@ -16,13 +20,13 @@ class FeatureMatcher : public VisionBase struct Result { cv::Rect box {}; - double score = 0.0; + int count = 0; json::value to_json() const { json::value root; root["box"] = json::array({ box.x, box.y, box.width, box.height }); - root["score"] = score; + root["count"] = count; return root; } }; @@ -37,11 +41,13 @@ class FeatureMatcher : public VisionBase private: ResultsVec foreach_rois(const cv::Mat& templ) const; - cv::FlannBasedMatcher create_matcher(const cv::Mat& templ) const; - void detect(const cv::Mat& image, bool green_mask, std::vector& keypoints, - cv::Mat& descriptors) const; - ResultsVec match(const cv::Rect& roi, cv::FlannBasedMatcher& matcher) const; - void draw_result(const cv::Rect& roi, const ResultsVec& results) const; + std::pair, cv::Mat> detect(const cv::Mat& image, bool green_mask) const; + cv::FlannBasedMatcher create_matcher(const std::vector& keypoints, const cv::Mat& descriptors) const; + + ResultsVec match(cv::FlannBasedMatcher& matcher, const std::vector& keypoints_1, + const cv::Rect& roi_2) const; + void draw_result(const cv::Mat& templ, const std::vector& keypoints_1, const cv::Rect& roi, + const std::vector& keypoints_2, const std::vector& good_matches) const; void filter(ResultsVec& results, int count) const; FeatureMatcherParam param_; @@ -49,3 +55,23 @@ class FeatureMatcher : public VisionBase 
}; MAA_VISION_NS_END + +MAA_NS_BEGIN + +inline std::ostream& operator<<(std::ostream& os, const MAA_VISION_NS::FeatureMatcher::Result& res) +{ + os << res.to_json().to_string(); + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const MAA_VISION_NS::FeatureMatcher::ResultsVec& results) +{ + json::array root; + for (const auto& res : results) { + root.emplace_back(res.to_json()); + } + os << root.to_string(); + return os; +} + +MAA_NS_END diff --git a/source/MaaFramework/Vision/VisionTypes.h b/source/MaaFramework/Vision/VisionTypes.h index ad3347dac..8d5d734ed 100644 --- a/source/MaaFramework/Vision/VisionTypes.h +++ b/source/MaaFramework/Vision/VisionTypes.h @@ -111,7 +111,7 @@ struct FeatureMatcherParam inline static constexpr Detector kDefaultDetector = Detector::SURF; inline static constexpr Matcher kDefaultMatcher = Matcher::KNN; - inline static constexpr int kDefaultHessianThreshold = 100; + inline static constexpr double kDefaultHessianThreshold = 100.0; inline static constexpr double kDefaultDistanceRatio = 0.6; inline static constexpr int kDefaultCount = 4; @@ -120,7 +120,7 @@ struct FeatureMatcherParam bool green_mask = false; Detector detector = kDefaultDetector; - int hessian = kDefaultHessianThreshold; + double hessian = kDefaultHessianThreshold; Matcher matcher = kDefaultMatcher;