Skip to content

Commit

Permalink
feat: 进一步完成FeatureMatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
MistEO committed Oct 11, 2023
1 parent 825bc52 commit 3c8be08
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 41 deletions.
43 changes: 43 additions & 0 deletions source/MaaFramework/Resource/PipelineResMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ bool PipelineResMgr::parse_recognition(const json::value& input, Recognition::Ty
{ kDefaultRecognitionFlag, default_type },
{ "DirectHit", Type::DirectHit },
{ "TemplateMatch", Type::TemplateMatch },
{ "FeatureMatch", Type::FeatureMatch },
{ "OCR", Type::OCR },
{ "NeuralNetworkClassify", Type::NeuralNetworkClassify },
{ "NeuralNetworkDetect", Type::NeuralNetworkDetect },
Expand Down Expand Up @@ -366,6 +367,12 @@ bool PipelineResMgr::parse_recognition(const json::value& input, Recognition::Ty
same_type ? std::get<TemplateMatcherParam>(default_param)
: TemplateMatcherParam {});

case Type::FeatureMatch:
out_param = FeatureMatcherParam {};
return parse_feature_matcher_param(input, std::get<FeatureMatcherParam>(out_param),
same_type ? std::get<FeatureMatcherParam>(default_param)
: FeatureMatcherParam {});

case Type::NeuralNetworkClassify:
out_param = NeuralNetworkClassifierParam {};
return parse_nn_classifier_param(input, std::get<NeuralNetworkClassifierParam>(out_param),
Expand Down Expand Up @@ -456,6 +463,42 @@ bool PipelineResMgr::parse_template_matcher_param(const json::value& input, MAA_
return true;
}

// Parses the parameters of a "FeatureMatch" recognition node.
//
// Reads the ROI plus the "template", "green_mask", "hessian", "distance_ratio" and "count"
// fields from `input` into `output`. `default_value` is handed to every field lookup as the
// fallback source. Parsing stops at the first field that fails: the failure is logged with the
// name of the offending field and false is returned. Returns true when every field was read.
bool PipelineResMgr::parse_feature_matcher_param(const json::value& input, MAA_VISION_NS::FeatureMatcherParam& output,
                                                 const MAA_VISION_NS::FeatureMatcherParam& default_value)
{
    if (!parse_roi(input, output.roi, default_value.roi)) {
        LogError << "failed to parse_roi" << VAR(input);
        return false;
    }

    if (!get_and_check_value(input, "template", output.template_path, default_value.template_path)) {
        LogError << "failed to get_and_check_value template_path" << VAR(input);
        return false;
    }

    if (!get_and_check_value(input, "green_mask", output.green_mask, default_value.green_mask)) {
        LogError << "failed to get_and_check_value green_mask" << VAR(input);
        return false;
    }

    if (!get_and_check_value(input, "hessian", output.hessian, default_value.hessian)) {
        LogError << "failed to get_and_check_value hessian" << VAR(input);
        return false;
    }

    // Each log message names the field that actually failed, so a bad pipeline file can be
    // diagnosed from the log alone.
    if (!get_and_check_value(input, "distance_ratio", output.distance_ratio, default_value.distance_ratio)) {
        LogError << "failed to get_and_check_value distance_ratio" << VAR(input);
        return false;
    }

    if (!get_and_check_value(input, "count", output.count, default_value.count)) {
        LogError << "failed to get_and_check_value count" << VAR(input);
        return false;
    }

    return true;
}

bool PipelineResMgr::parse_ocrer_param(const json::value& input, MAA_VISION_NS::OCRerParam& output,
const MAA_VISION_NS::OCRerParam& default_value)
{
Expand Down
2 changes: 2 additions & 0 deletions source/MaaFramework/Resource/PipelineResMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class PipelineResMgr : public NonCopyable
// const MAA_VISION_NS::DirectHitParam& default_value);
static bool parse_template_matcher_param(const json::value& input, MAA_VISION_NS::TemplateMatcherParam& output,
const MAA_VISION_NS::TemplateMatcherParam& default_value);
// Parses the FeatureMatch recognition parameters (roi, "template", "green_mask", "hessian",
// "distance_ratio", "count") from `input` into `output`. `default_value` is passed to each
// field lookup as the fallback source. Returns false as soon as one field fails to parse.
static bool parse_feature_matcher_param(const json::value& input, MAA_VISION_NS::FeatureMatcherParam& output,
const MAA_VISION_NS::FeatureMatcherParam& default_value);
static bool parse_ocrer_param(const json::value& input, MAA_VISION_NS::OCRerParam& output,
const MAA_VISION_NS::OCRerParam& default_value);
static bool parse_custom_recognizer_param(const json::value& input, MAA_VISION_NS::CustomRecognizerParam& output,
Expand Down
7 changes: 4 additions & 3 deletions source/MaaFramework/Resource/PipelineTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ enum class Type
Invalid = 0,
DirectHit,
TemplateMatch,
FeatureMatch,
OCR,
NeuralNetworkClassify,
NeuralNetworkDetect,
Expand All @@ -31,9 +32,9 @@ enum class Type
};

using Param = std::variant<std::monostate, MAA_VISION_NS::DirectHitParam, MAA_VISION_NS::TemplateMatcherParam,
MAA_VISION_NS::OCRerParam, MAA_VISION_NS::NeuralNetworkClassifierParam,
MAA_VISION_NS::NeuralNetworkDetectorParam, MAA_VISION_NS::ColorMatcherParam,
MAA_VISION_NS::CustomRecognizerParam>;
MAA_VISION_NS::FeatureMatcherParam, MAA_VISION_NS::OCRerParam,
MAA_VISION_NS::NeuralNetworkClassifierParam, MAA_VISION_NS::NeuralNetworkDetectorParam,
MAA_VISION_NS::ColorMatcherParam, MAA_VISION_NS::CustomRecognizerParam>;
} // namespace Recognition

namespace Action
Expand Down
37 changes: 37 additions & 0 deletions source/MaaFramework/Task/Recognizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "Utils/Logger.h"
#include "Vision/ColorMatcher.h"
#include "Vision/CustomRecognizer.h"
#include "Vision/FeatureMatcher.h"
#include "Vision/NeuralNetworkClassifier.h"
#include "Vision/NeuralNetworkDetector.h"
#include "Vision/OCRer.h"
Expand Down Expand Up @@ -35,6 +36,10 @@ std::optional<Recognizer::Result> Recognizer::recognize(const cv::Mat& image, co
result = template_match(image, std::get<TemplateMatcherParam>(task_data.rec_param), task_data.name);
break;

case Type::FeatureMatch:
result = feature_match(image, std::get<FeatureMatcherParam>(task_data.rec_param), task_data.name);
break;

case Type::ColorMatch:
result = color_match(image, std::get<ColorMatcherParam>(task_data.rec_param), task_data.name);
break;
Expand Down Expand Up @@ -117,6 +122,38 @@ std::optional<Recognizer::Result> Recognizer::template_match(const cv::Mat& imag
return Result { .box = box, .detail = detail.to_string() };
}

// Runs a FeatureMatcher over `image` using the template named by `param.template_path`.
//
// Returns std::nullopt when no resource is bound or when the matcher reports no results.
// Otherwise the returned Result carries the box of the first match, and `detail` is the JSON
// array (serialized to a string) of every match in the order the matcher returned them.
std::optional<Recognizer::Result> Recognizer::feature_match(const cv::Mat& image,
                                                            const MAA_VISION_NS::FeatureMatcherParam& param,
                                                            const std::string& name)
{
    using namespace MAA_VISION_NS;

    if (!resource()) {
        LogError << "Resource not binded";
        return std::nullopt;
    }

    FeatureMatcher feature_matcher;
    feature_matcher.set_image(image);
    feature_matcher.set_name(name);
    feature_matcher.set_param(param);

    // The template image is looked up in the resource's template store by its path.
    std::shared_ptr<cv::Mat> template_image = resource()->template_res().image(param.template_path);
    feature_matcher.set_template(std::move(template_image));

    auto matches = feature_matcher.analyze();
    if (matches.empty()) {
        return std::nullopt;
    }

    json::array all_matches;
    for (const auto& match : matches) {
        all_matches.emplace_back(match.to_json());
    }

    const cv::Rect& first_box = matches.front().box;
    return Result { .box = first_box, .detail = all_matches.to_string() };
}

std::optional<Recognizer::Result> Recognizer::color_match(const cv::Mat& image,
const MAA_VISION_NS::ColorMatcherParam& param,
const std::string& name)
Expand Down
2 changes: 2 additions & 0 deletions source/MaaFramework/Task/Recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class Recognizer
std::optional<Result> direct_hit();
std::optional<Result> template_match(const cv::Mat& image, const MAA_VISION_NS::TemplateMatcherParam& param,
const std::string& name);
// Feature-based recognition: matches the template referenced by `param` against `image`.
// `name` is forwarded to the matcher. Returns std::nullopt when nothing is matched.
std::optional<Result> feature_match(const cv::Mat& image, const MAA_VISION_NS::FeatureMatcherParam& param,
const std::string& name);
std::optional<Result> color_match(const cv::Mat& image, const MAA_VISION_NS::ColorMatcherParam& param,
const std::string& name);
std::optional<Result> ocr(const cv::Mat& image, const MAA_VISION_NS::OCRerParam& param, const std::string& name);
Expand Down
91 changes: 62 additions & 29 deletions source/MaaFramework/Vision/FeatureMatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,37 +40,23 @@ FeatureMatcher::ResultsVec FeatureMatcher::foreach_rois(const cv::Mat& templ) co
return {};
}

auto matcher = create_matcher(templ);
auto [keypoints, descriptors] = detect(templ, param_.green_mask);
auto matcher = create_matcher(keypoints, descriptors);

if (param_.roi.empty()) {
return match(cv::Rect(0, 0, image_.cols, image_.rows), matcher);
return match(matcher, keypoints, cv::Rect(0, 0, image_.cols, image_.rows));
}

ResultsVec results;
for (const cv::Rect& roi : param_.roi) {
ResultsVec res = match(roi, matcher);
ResultsVec res = match(matcher, keypoints, roi);
results.insert(results.end(), std::make_move_iterator(res.begin()), std::make_move_iterator(res.end()));
}

return results;
}

// Detects features on the template image (honouring param_.green_mask) and returns a FLANN
// matcher whose training set is the template's descriptors.
cv::FlannBasedMatcher FeatureMatcher::create_matcher(const cv::Mat& templ) const
{
    std::vector<cv::KeyPoint> templ_keypoints;
    cv::Mat templ_descriptors;
    detect(templ, param_.green_mask, templ_keypoints, templ_descriptors);

    // A single-element training set: the descriptors of the one template image.
    std::vector<cv::Mat> training_set;
    training_set.push_back(templ_descriptors);

    cv::FlannBasedMatcher flann;
    flann.add(training_set);
    flann.train();
    return flann;
}

void FeatureMatcher::detect(const cv::Mat& image, bool green_mask, std::vector<cv::KeyPoint>& keypoints,
cv::Mat& descriptors) const
std::pair<std::vector<cv::KeyPoint>, cv::Mat> FeatureMatcher::detect(const cv::Mat& image, bool green_mask) const
{
auto detector = cv::xfeatures2d::SURF::create(param_.hessian);

Expand All @@ -80,33 +66,80 @@ void FeatureMatcher::detect(const cv::Mat& image, bool green_mask, std::vector<c
mask = ~mask;
}

std::vector<cv::KeyPoint> keypoints;
cv::Mat descriptors;
detector->detectAndCompute(image, mask, keypoints, descriptors);

return std::make_pair(std::move(keypoints), std::move(descriptors));
}

FeatureMatcher::ResultsVec FeatureMatcher::match(const cv::Rect& roi, cv::FlannBasedMatcher& matcher) const
// Builds a FLANN-based matcher trained on `descriptors` (the template's descriptors).
//
// `keypoints` is accepted so the signature pairs with the output of detect(), but it is not
// used here. Only the descriptors are added to the training set.
cv::FlannBasedMatcher FeatureMatcher::create_matcher(const std::vector<cv::KeyPoint>& keypoints,
                                                     const cv::Mat& descriptors) const
{
    // Deliberately unused; silences unused-parameter warnings.
    std::ignore = keypoints;

    std::vector<cv::Mat> train_desc(1, descriptors);
    cv::FlannBasedMatcher matcher;
    matcher.add(train_desc);
    matcher.train();

    return matcher;
}

// Matches the features found inside `roi_2` of the current image against the template
// descriptors that `matcher` was trained with.
//
// `keypoints_1` are the template keypoints; they are only forwarded to draw_result().
// For every query descriptor the two nearest template descriptors are requested and the
// ratio test is applied: the best match is kept only when it is clearly better than the
// runner-up (best.distance < distance_ratio * second.distance).
//
// TODO: results are not assembled yet; the function currently only collects the good matches,
// forwards them to draw_result() and returns an empty vector.
FeatureMatcher::ResultsVec FeatureMatcher::match(cv::FlannBasedMatcher& matcher,
                                                 const std::vector<cv::KeyPoint>& keypoints_1,
                                                 const cv::Rect& roi_2) const
{
    auto image_2 = image_with_roi(roi_2);
    auto [keypoints_2, descriptors_2] = detect(image_2, false);

    // descriptors_2 is the query set, so DMatch::queryIdx indexes keypoints_2 and
    // DMatch::trainIdx indexes the template keypoints.
    std::vector<std::vector<cv::DMatch>> match_points;
    matcher.knnMatch(descriptors_2, match_points, 2);

    std::vector<cv::DMatch> good_matches;
    std::vector<cv::Point> good_points;
    for (const auto& point : match_points) {
        // The ratio test needs both the best and the second-best candidate.
        if (point.size() != 2) {
            continue;
        }

        const cv::DMatch& best = point[0];
        const cv::DMatch& second = point[1];

        // Reject ambiguous matches: the best candidate must beat the runner-up by the
        // configured ratio.
        const double threshold = param_.distance_ratio * second.distance;
        if (!(best.distance < threshold)) {
            continue;
        }
        good_matches.emplace_back(best);

        // NOTE(review): presumably relative to the image returned by image_with_roi - confirm.
        cv::Point pt = keypoints_2[best.queryIdx].pt;
        good_points.emplace_back(pt);
    }

    draw_result(*template_, keypoints_1, roi_2, keypoints_2, good_matches);

    return {};
}

void FeatureMatcher::draw_result(const cv::Rect& roi, const ResultsVec& results) const {}
// Debug visualisation of the accepted matches. Does nothing unless debug_draw_ is set.
//
// `templ` / `keypoints_1` describe the template, `keypoints_2` the searched image, and
// `good_matches` are the matches produced with the searched image as the query side
// (queryIdx -> keypoints_2, trainIdx -> keypoints_1). The rendered picture is saved when
// save_draw_ is set.
void FeatureMatcher::draw_result(const cv::Mat& templ, const std::vector<cv::KeyPoint>& keypoints_1,
                                 const cv::Rect& roi, const std::vector<cv::KeyPoint>& keypoints_2,
                                 const std::vector<cv::DMatch>& good_matches) const
{
    if (!debug_draw_) {
        return;
    }

    cv::Mat image_draw = draw_roi(roi);

    // cv::drawMatches reads DMatch::queryIdx against its first keypoint list and
    // DMatch::trainIdx against its second, so the query side (the searched image) must come
    // first. The output goes to a separate Mat: reusing an input image as the output would
    // make drawMatches reallocate it before it is read.
    // NOTE(review): keypoints_2 may be relative to the ROI crop rather than to image_draw -
    // confirm against image_with_roi().
    cv::Mat matches_draw;
    cv::drawMatches(image_draw, keypoints_2, templ, keypoints_1, good_matches, matches_draw);

    if (save_draw_) {
        save_image(matches_draw);
    }
}

void FeatureMatcher::filter(ResultsVec& results, int count) const {}
// Placeholder: currently a no-op. Both arguments are deliberately discarded so the
// not-yet-implemented filter does not trigger unused-parameter warnings.
void FeatureMatcher::filter(ResultsVec& results, int count) const
{
    static_cast<void>(results);
    static_cast<void>(count);
}

MAA_VISION_NS_END
40 changes: 33 additions & 7 deletions source/MaaFramework/Vision/FeatureMatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
#include <ostream>
#include <vector>

#include "Conf/Conf.h"

MAA_SUPPRESS_CV_WARNINGS_BEGIN
#include <opencv2/features2d.hpp>
MAA_SUPPRESS_CV_WARNINGS_END

#include "VisionBase.h"
#include "VisionTypes.h"
Expand All @@ -16,13 +20,13 @@ class FeatureMatcher : public VisionBase
// One feature-matching hit: the matched rectangle plus its metrics, serializable to JSON.
// NOTE(review): both `score` and `count` are present in this view; it may be showing both
// sides of a change that replaces `score` with `count` - confirm which fields are current.
struct Result
{
cv::Rect box {};
double score = 0.0;
int count = 0;

// Serializes the result as an object with "box" ([x, y, width, height]), "score" and "count".
json::value to_json() const
{
json::value root;
root["box"] = json::array({ box.x, box.y, box.width, box.height });
root["score"] = score;
root["count"] = count;
return root;
}
};
Expand All @@ -37,15 +41,37 @@ class FeatureMatcher : public VisionBase

private:
ResultsVec foreach_rois(const cv::Mat& templ) const;
cv::FlannBasedMatcher create_matcher(const cv::Mat& templ) const;
void detect(const cv::Mat& image, bool green_mask, std::vector<cv::KeyPoint>& keypoints,
cv::Mat& descriptors) const;
ResultsVec match(const cv::Rect& roi, cv::FlannBasedMatcher& matcher) const;
void draw_result(const cv::Rect& roi, const ResultsVec& results) const;
std::pair<std::vector<cv::KeyPoint>, cv::Mat> detect(const cv::Mat& image, bool green_mask) const;
cv::FlannBasedMatcher create_matcher(const std::vector<cv::KeyPoint>& keypoints, const cv::Mat& descriptors) const;

ResultsVec match(cv::FlannBasedMatcher& matcher, const std::vector<cv::KeyPoint>& keypoints_1,
const cv::Rect& roi_2) const;
void draw_result(const cv::Mat& templ, const std::vector<cv::KeyPoint>& keypoints_1, const cv::Rect& roi,
const std::vector<cv::KeyPoint>& keypoints_2, const std::vector<cv::DMatch>& good_matches) const;
void filter(ResultsVec& results, int count) const;

FeatureMatcherParam param_;
std::shared_ptr<cv::Mat> template_;
};

MAA_VISION_NS_END

MAA_NS_BEGIN

// Streams a single FeatureMatcher result as the string form of its JSON representation.
inline std::ostream& operator<<(std::ostream& os, const MAA_VISION_NS::FeatureMatcher::Result& res)
{
    const auto serialized = res.to_json().to_string();
    os << serialized;
    return os;
}

// Streams a list of FeatureMatcher results as one JSON array, one element per result, in
// the order they appear in `results`.
inline std::ostream& operator<<(std::ostream& os, const MAA_VISION_NS::FeatureMatcher::ResultsVec& results)
{
    json::array root;
    for (const auto& res : results) {
        root.emplace_back(res.to_json());
    }
    os << root.to_string();
    return os;
}

MAA_NS_END
4 changes: 2 additions & 2 deletions source/MaaFramework/Vision/VisionTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ struct FeatureMatcherParam

inline static constexpr Detector kDefaultDetector = Detector::SURF;
inline static constexpr Matcher kDefaultMatcher = Matcher::KNN;
inline static constexpr int kDefaultHessianThreshold = 100;
inline static constexpr double kDefaultHessianThreshold = 100.0;
inline static constexpr double kDefaultDistanceRatio = 0.6;
inline static constexpr int kDefaultCount = 4;

Expand All @@ -120,7 +120,7 @@ struct FeatureMatcherParam
bool green_mask = false;

Detector detector = kDefaultDetector;
int hessian = kDefaultHessianThreshold;
double hessian = kDefaultHessianThreshold;

Matcher matcher = kDefaultMatcher;

Expand Down

0 comments on commit 3c8be08

Please sign in to comment.