Skip to content

Commit

Permalink
refactor: 重命名部分识别器
Browse files Browse the repository at this point in the history
Classify->NeuralNetworkClassify
Detect->NeuralNetworkDetect
  • Loading branch information
MistEO committed Oct 10, 2023
1 parent 495c7fb commit 8f4a4bf
Show file tree
Hide file tree
Showing 17 changed files with 101 additions and 95 deletions.
4 changes: 2 additions & 2 deletions docs/zh_cn/3.1-任务流水线协议.md
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ graph LR;
若为空,则为 `model/ocr` 根目录下的模型文件。
文件夹中需要包含 `rec.onnx`, `det.onnx`, `keys.txt` 三个文件。
### `Classify`
### `NeuralNetworkClassify`
深度学习分类,判断图像中的 **固定位置** 是否为预期的“类别”。
Expand Down Expand Up @@ -326,7 +326,7 @@ graph LR;
注意这些值需要与模型实际输出相符。
### `Detect`
### `NeuralNetworkDetect`
深度学习检测,泛化版“找图”。
Expand Down
16 changes: 8 additions & 8 deletions source/MaaFramework/MaaFramework.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,16 @@
<ClInclude Include="Utils\StringMisc.hpp" />
<ClInclude Include="Utils\TempPath.hpp" />
<ClInclude Include="Utils\Time.hpp" />
<ClInclude Include="Vision\Classifier.h" />
<ClInclude Include="Vision\NeuralNetworkClassifier.h" />
<ClInclude Include="Vision\ColorMatcher.h" />
<ClInclude Include="Vision\Comparator.h" />
<ClInclude Include="Vision\TemplateComparator.h" />
<ClInclude Include="Vision\CustomRecognizer.h" />
<ClInclude Include="Vision\Matcher.h" />
<ClInclude Include="Vision\TemplateMatcher.h" />
<ClInclude Include="Vision\OCRer.h" />
<ClInclude Include="Vision\VisionTypes.h" />
<ClInclude Include="Vision\VisionUtils.hpp" />
<ClInclude Include="Vision\VisionBase.h" />
<ClInclude Include="Vision\Detector.h" />
<ClInclude Include="Vision\NeuralNetworkDetector.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="API\MaaBuffer.cpp" />
Expand All @@ -107,14 +107,14 @@
<ClCompile Include="Task\Recognizer.cpp" />
<ClCompile Include="Task\SyncContext.cpp" />
<ClCompile Include="Task\PipelineTask.cpp" />
<ClCompile Include="Vision\Classifier.cpp" />
<ClCompile Include="Vision\NeuralNetworkClassifier.cpp" />
<ClCompile Include="Vision\ColorMatcher.cpp" />
<ClCompile Include="Vision\Comparator.cpp" />
<ClCompile Include="Vision\TemplateComparator.cpp" />
<ClCompile Include="Vision\CustomRecognizer.cpp" />
<ClCompile Include="Vision\Matcher.cpp" />
<ClCompile Include="Vision\TemplateMatcher.cpp" />
<ClCompile Include="Vision\OCRer.cpp" />
<ClCompile Include="Vision\VisionBase.cpp" />
<ClCompile Include="Vision\Detector.cpp" />
<ClCompile Include="Vision\NeuralNetworkDetector.cpp" />
</ItemGroup>
<PropertyGroup Label="Globals">
<VCProjectVersion>16.0</VCProjectVersion>
Expand Down
35 changes: 20 additions & 15 deletions source/MaaFramework/Resource/PipelineResMgr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,8 @@ bool PipelineResMgr::parse_recognition(const json::value& input, Recognition::Ty
{ "DirectHit", Type::DirectHit },
{ "TemplateMatch", Type::TemplateMatch },
{ "OCR", Type::OCR },
{ "Classify", Type::Classify },
{ "Detect", Type::Detect },
{ "NeuralNetworkClassify", Type::NeuralNetworkClassify },
{ "NeuralNetworkDetect", Type::NeuralNetworkDetect },
{ "ColorMatch", Type::ColorMatch },
{ "Custom", Type::Custom },
};
Expand All @@ -366,15 +366,17 @@ bool PipelineResMgr::parse_recognition(const json::value& input, Recognition::Ty
same_type ? std::get<TemplateMatcherParam>(default_param)
: TemplateMatcherParam {});

case Type::Classify:
out_param = ClassifierParam {};
return parse_classifier_param(input, std::get<ClassifierParam>(out_param),
same_type ? std::get<ClassifierParam>(default_param) : ClassifierParam {});
case Type::NeuralNetworkClassify:
out_param = NeuralNetworkClassifierParam {};
return parse_nn_classifier_param(input, std::get<NeuralNetworkClassifierParam>(out_param),
same_type ? std::get<NeuralNetworkClassifierParam>(default_param)
: NeuralNetworkClassifierParam {});

case Type::Detect:
out_param = DetectorParam {};
return parse_detector_param(input, std::get<DetectorParam>(out_param),
same_type ? std::get<DetectorParam>(default_param) : DetectorParam {});
case Type::NeuralNetworkDetect:
out_param = NeuralNetworkDetectorParam {};
return parse_nn_detector_param(input, std::get<NeuralNetworkDetectorParam>(out_param),
same_type ? std::get<NeuralNetworkDetectorParam>(default_param)
: NeuralNetworkDetectorParam {});

case Type::OCR:
out_param = OCRerParam {};
Expand Down Expand Up @@ -531,8 +533,9 @@ bool PipelineResMgr::parse_custom_recognizer_param(const json::value& input,
return true;
}

bool PipelineResMgr::parse_classifier_param(const json::value& input, MAA_VISION_NS::ClassifierParam& output,
const MAA_VISION_NS::ClassifierParam& default_value)
bool PipelineResMgr::parse_nn_classifier_param(const json::value& input,
MAA_VISION_NS::NeuralNetworkClassifierParam& output,
const MAA_VISION_NS::NeuralNetworkClassifierParam& default_value)
{
if (!parse_roi(input, output.roi, default_value.roi)) {
LogError << "failed to parse_roi" << VAR(input);
Expand Down Expand Up @@ -566,8 +569,9 @@ bool PipelineResMgr::parse_classifier_param(const json::value& input, MAA_VISION
return true;
}

bool PipelineResMgr::parse_detector_param(const json::value& input, MAA_VISION_NS::DetectorParam& output,
const MAA_VISION_NS::DetectorParam& default_value)
bool PipelineResMgr::parse_nn_detector_param(const json::value& input,
MAA_VISION_NS::NeuralNetworkDetectorParam& output,
const MAA_VISION_NS::NeuralNetworkDetectorParam& default_value)
{
if (!parse_roi(input, output.roi, default_value.roi)) {
LogError << "failed to parse_roi" << VAR(input);
Expand Down Expand Up @@ -603,7 +607,8 @@ bool PipelineResMgr::parse_detector_param(const json::value& input, MAA_VISION_N
return false;
}
if (output.thresholds.empty()) {
output.thresholds = std::vector(output.expected.size(), MAA_VISION_NS::DetectorParam::kDefaultThreshold);
output.thresholds =
std::vector(output.expected.size(), MAA_VISION_NS::NeuralNetworkDetectorParam::kDefaultThreshold);
}
else if (output.expected.size() != output.thresholds.size()) {
LogError << "templates.size() != thresholds.size()" << VAR(output.expected.size())
Expand Down
8 changes: 4 additions & 4 deletions source/MaaFramework/Resource/PipelineResMgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ class PipelineResMgr : public NonCopyable
const MAA_VISION_NS::OCRerParam& default_value);
static bool parse_custom_recognizer_param(const json::value& input, MAA_VISION_NS::CustomRecognizerParam& output,
const MAA_VISION_NS::CustomRecognizerParam& default_value);
static bool parse_classifier_param(const json::value& input, MAA_VISION_NS::ClassifierParam& output,
const MAA_VISION_NS::ClassifierParam& default_value);
static bool parse_detector_param(const json::value& input, MAA_VISION_NS::DetectorParam& output,
const MAA_VISION_NS::DetectorParam& default_value);
static bool parse_nn_classifier_param(const json::value& input, MAA_VISION_NS::NeuralNetworkClassifierParam& output,
const MAA_VISION_NS::NeuralNetworkClassifierParam& default_value);
static bool parse_nn_detector_param(const json::value& input, MAA_VISION_NS::NeuralNetworkDetectorParam& output,
const MAA_VISION_NS::NeuralNetworkDetectorParam& default_value);
static bool parse_color_matcher_param(const json::value& input, MAA_VISION_NS::ColorMatcherParam& output,
const MAA_VISION_NS::ColorMatcherParam& default_value);

Expand Down
9 changes: 5 additions & 4 deletions source/MaaFramework/Resource/PipelineTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ enum class Type
DirectHit,
TemplateMatch,
OCR,
Classify,
Detect,
NeuralNetworkClassify,
NeuralNetworkDetect,
ColorMatch,
Custom,
};

using Param = std::variant<std::monostate, MAA_VISION_NS::DirectHitParam, MAA_VISION_NS::TemplateMatcherParam,
MAA_VISION_NS::OCRerParam, MAA_VISION_NS::ClassifierParam, MAA_VISION_NS::DetectorParam,
MAA_VISION_NS::ColorMatcherParam, MAA_VISION_NS::CustomRecognizerParam>;
MAA_VISION_NS::OCRerParam, MAA_VISION_NS::NeuralNetworkClassifierParam,
MAA_VISION_NS::NeuralNetworkDetectorParam, MAA_VISION_NS::ColorMatcherParam,
MAA_VISION_NS::CustomRecognizerParam>;
} // namespace Recognition

namespace Action
Expand Down
4 changes: 2 additions & 2 deletions source/MaaFramework/Task/Actuator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "Instance/InstanceStatus.h"
#include "Task/CustomAction.h"
#include "Utils/Logger.h"
#include "Vision/Comparator.h"
#include "Vision/TemplateComparator.h"

MAA_TASK_NS_BEGIN

Expand Down Expand Up @@ -105,7 +105,7 @@ void Actuator::wait_freezes(const MAA_RES_NS::WaitFreezesParam& param, const cv:

cv::Rect target = get_target_rect(param.target, cur_box);

Comparator comp;
TemplateComparator comp;
comp.set_param({
.roi = { target },
.threshold = param.threshold,
Expand Down
25 changes: 13 additions & 12 deletions source/MaaFramework/Task/Recognizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
#include "Instance/InstanceStatus.h"
#include "Resource/ResourceMgr.h"
#include "Utils/Logger.h"
#include "Vision/Classifier.h"
#include "Vision/ColorMatcher.h"
#include "Vision/CustomRecognizer.h"
#include "Vision/Detector.h"
#include "Vision/Matcher.h"
#include "Vision/NeuralNetworkClassifier.h"
#include "Vision/NeuralNetworkDetector.h"
#include "Vision/OCRer.h"
#include "Vision/TemplateMatcher.h"
#include "Vision/VisionUtils.hpp"

MAA_TASK_NS_BEGIN
Expand Down Expand Up @@ -43,12 +43,12 @@ std::optional<Recognizer::Result> Recognizer::recognize(const cv::Mat& image, co
result = ocr(image, std::get<OCRerParam>(task_data.rec_param), task_data.name);
break;

case Type::Classify:
result = classify(image, std::get<ClassifierParam>(task_data.rec_param), task_data.name);
case Type::NeuralNetworkClassify:
result = classify(image, std::get<NeuralNetworkClassifierParam>(task_data.rec_param), task_data.name);
break;

case Type::Detect:
result = detect(image, std::get<DetectorParam>(task_data.rec_param), task_data.name);
case Type::NeuralNetworkDetect:
result = detect(image, std::get<NeuralNetworkDetectorParam>(task_data.rec_param), task_data.name);
break;

case Type::Custom:
Expand Down Expand Up @@ -88,7 +88,7 @@ std::optional<Recognizer::Result> Recognizer::template_match(const cv::Mat& imag
return std::nullopt;
}

Matcher matcher;
TemplateMatcher matcher;
matcher.set_image(image);
matcher.set_name(name);
matcher.set_param(param);
Expand Down Expand Up @@ -183,7 +183,7 @@ std::optional<Recognizer::Result> Recognizer::ocr(const cv::Mat& image, const MA
}

std::optional<Recognizer::Result> Recognizer::classify(const cv::Mat& image,
const MAA_VISION_NS::ClassifierParam& param,
const MAA_VISION_NS::NeuralNetworkClassifierParam& param,
const std::string& name)
{
using namespace MAA_VISION_NS;
Expand All @@ -193,7 +193,7 @@ std::optional<Recognizer::Result> Recognizer::classify(const cv::Mat& image,
return std::nullopt;
}

Classifier classifier;
NeuralNetworkClassifier classifier;
classifier.set_image(image);
classifier.set_name(name);
classifier.set_param(param);
Expand All @@ -214,7 +214,8 @@ std::optional<Recognizer::Result> Recognizer::classify(const cv::Mat& image,
return Result { .box = box, .detail = std::move(detail) };
}

std::optional<Recognizer::Result> Recognizer::detect(const cv::Mat& image, const MAA_VISION_NS::DetectorParam& param,
std::optional<Recognizer::Result> Recognizer::detect(const cv::Mat& image,
const MAA_VISION_NS::NeuralNetworkDetectorParam& param,
const std::string& name)
{
using namespace MAA_VISION_NS;
Expand All @@ -224,7 +225,7 @@ std::optional<Recognizer::Result> Recognizer::detect(const cv::Mat& image, const
return std::nullopt;
}

Detector detector;
NeuralNetworkDetector detector;
detector.set_image(image);
detector.set_name(name);
detector.set_param(param);
Expand Down
7 changes: 3 additions & 4 deletions source/MaaFramework/Task/Recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ class Recognizer
const std::string& name);
std::optional<Result> color_match(const cv::Mat& image, const MAA_VISION_NS::ColorMatcherParam& param,
const std::string& name);
std::optional<Result> ocr(const cv::Mat& image, const MAA_VISION_NS::OCRerParam& param,
const std::string& name);
std::optional<Result> classify(const cv::Mat& image, const MAA_VISION_NS::ClassifierParam& param,
std::optional<Result> ocr(const cv::Mat& image, const MAA_VISION_NS::OCRerParam& param, const std::string& name);
std::optional<Result> classify(const cv::Mat& image, const MAA_VISION_NS::NeuralNetworkClassifierParam& param,
const std::string& name);
std::optional<Result> detect(const cv::Mat& image, const MAA_VISION_NS::DetectorParam& param,
std::optional<Result> detect(const cv::Mat& image, const MAA_VISION_NS::NeuralNetworkDetectorParam& param,
const std::string& name);
std::optional<Result> custom_recognize(const cv::Mat& image, const MAA_VISION_NS::CustomRecognizerParam& param,
const std::string& name);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "Classifier.h"
#include "NeuralNetworkClassifier.h"

#include <onnxruntime/onnxruntime_cxx_api.h>

Expand All @@ -8,7 +8,7 @@

MAA_VISION_NS_BEGIN

Classifier::ResultsVec Classifier::analyze() const
NeuralNetworkClassifier::ResultsVec NeuralNetworkClassifier::analyze() const
{
LogFunc << name_;

Expand Down Expand Up @@ -39,7 +39,7 @@ Classifier::ResultsVec Classifier::analyze() const
return results;
}

Classifier::ResultsVec Classifier::foreach_rois() const
NeuralNetworkClassifier::ResultsVec NeuralNetworkClassifier::foreach_rois() const
{
if (param_.roi.empty()) {
return { classify(cv::Rect(0, 0, image_.cols, image_.rows)) };
Expand All @@ -54,7 +54,7 @@ Classifier::ResultsVec Classifier::foreach_rois() const
return results;
}

Classifier::Result Classifier::classify(const cv::Rect& roi) const
NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify(const cv::Rect& roi) const
{
if (!session_) {
LogError << "OrtSession not loaded";
Expand Down Expand Up @@ -100,7 +100,7 @@ Classifier::Result Classifier::classify(const cv::Rect& roi) const
return result;
}

void Classifier::draw_result(const Result& res) const
void NeuralNetworkClassifier::draw_result(const Result& res) const
{
if (!debug_draw_) {
return;
Expand All @@ -122,7 +122,7 @@ void Classifier::draw_result(const Result& res) const
}
}

void Classifier::filter(ResultsVec& results, const std::vector<size_t>& expected) const
void NeuralNetworkClassifier::filter(ResultsVec& results, const std::vector<size_t>& expected) const
{
if (expected.empty()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

MAA_VISION_NS_BEGIN

class Classifier : public VisionBase
class NeuralNetworkClassifier : public VisionBase
{
public:
struct Result
Expand All @@ -37,7 +37,7 @@ class Classifier : public VisionBase
using ResultsVec = std::vector<Result>;

public:
void set_param(ClassifierParam param) { param_ = std::move(param); }
// Replaces the stored parameters; the by-value argument is moved into param_.
void set_param(NeuralNetworkClassifierParam param)
{
    param_ = std::move(param);
}
// Replaces the stored session handle; the shared_ptr argument is moved into session_.
void set_session(std::shared_ptr<Ort::Session> session)
{
    session_ = std::move(session);
}
ResultsVec analyze() const;

Expand All @@ -48,21 +48,21 @@ class Classifier : public VisionBase

void filter(ResultsVec& results, const std::vector<size_t>& expected) const;

ClassifierParam param_;
NeuralNetworkClassifierParam param_;
std::shared_ptr<Ort::Session> session_ = nullptr;
};

MAA_VISION_NS_END

MAA_NS_BEGIN

inline std::ostream& operator<<(std::ostream& os, const MAA_VISION_NS::Classifier::Result& res)
// Stream insertion for a single NeuralNetworkClassifier result:
// writes res.to_json().to_string() to `os` and returns the same stream.
inline std::ostream& operator<<(std::ostream& os, const MAA_VISION_NS::NeuralNetworkClassifier::Result& res)
{
    // Keep the JSON value alive in a named local while its text form is written.
    const auto serialized = res.to_json();
    os << serialized.to_string();
    return os;
}

inline std::ostream& operator<<(std::ostream& os, const MAA_VISION_NS::Classifier::ResultsVec& resutls)
inline std::ostream& operator<<(std::ostream& os, const MAA_VISION_NS::NeuralNetworkClassifier::ResultsVec& resutls)
{
json::array root;
for (const auto& res : resutls) {
Expand Down
Loading

0 comments on commit 8f4a4bf

Please sign in to comment.