forked from oreillymedia/Learning-OpenCV-3_examples
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding the Decision tree and the mushroom database
- Loading branch information
Prasanna
committed
Jun 11, 2017
1 parent
a1630a0
commit 7375ccf
Showing
7 changed files
with
8,390 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
#include <opencv2/opencv.hpp> | ||
#include <stdio.h> | ||
#include <iostream> | ||
using namespace std; | ||
using namespace cv; | ||
int main(int argc, char *argv[]) { | ||
// If the caller gave a filename, great. Otherwise, use a default. | ||
// | ||
const char *csv_file_name = argc >= 2 ? argv[1] : "agaricus-lepiota.data"; | ||
cout << "OpenCV Version: " << CV_VERSION << endl; | ||
// Read in the CSV file that we were given. | ||
// | ||
cv::Ptr<cv::ml::TrainData> data_set = | ||
cv::ml::TrainData::loadFromCSV(csv_file_name, | ||
// Input file name | ||
0, | ||
// Header lines (ignore this many) | ||
0, | ||
// Responses are (start) at thie column | ||
1, | ||
// Inputs start at this column | ||
"cat[0-22]" | ||
// All 23 columns are categorical | ||
); | ||
// Use defaults for delimeter (',') and missch ('?') | ||
// Verify that we read in what we think. | ||
// | ||
int n_samples = data_set->getNSamples(); | ||
if (n_samples == 0) { | ||
cerr << "Could not read file: " << csv_file_name << endl; | ||
exit(-1); | ||
} else { | ||
cout << "Read " << n_samples << " samples from " << csv_file_name << endl; | ||
} | ||
// Split the data, so that 90% is train data | ||
// | ||
data_set->setTrainTestSplitRatio(0.90, false); | ||
int n_train_samples = data_set->getNTrainSamples(); | ||
int n_test_samples = data_set->getNTestSamples(); | ||
cout << "Found " << n_train_samples << " Train Samples, and " | ||
<< n_test_samples << " Test Samples" << endl; | ||
// Create a DTrees classifier. | ||
// | ||
cv::Ptr<cv::ml::RTrees> dtree = cv::ml::RTrees::create(); | ||
// set parameters | ||
// | ||
// These are the parameters from the old mushrooms.cpp code | ||
// Set up priors to penalize "poisonous" 10x as much as "edible" | ||
// | ||
float _priors[] = {1.0, 10.0}; | ||
cv::Mat priors(1, 2, CV_32F, _priors); | ||
dtree->setMaxDepth(8); | ||
dtree->setMinSampleCount(10); | ||
dtree->setRegressionAccuracy(0.01f); | ||
dtree->setUseSurrogates(false /* true */); | ||
dtree->setMaxCategories(15); | ||
dtree->setCVFolds(0 /*10*/); // nonzero causes core dump | ||
dtree->setUse1SERule(true); | ||
dtree->setTruncatePrunedTree(true); | ||
// dtree->setPriors( priors ); | ||
dtree->setPriors(cv::Mat()); // ignore priors for now... | ||
// Now train the model | ||
// NB: we are only using the "train" part of the data set | ||
// | ||
dtree->train(data_set); | ||
// Having successfully trained the data, we should be able | ||
// to calculate the error on both the training data, as well | ||
// as the test data that we held out. | ||
// | ||
cv::Mat results; | ||
float train_performance = dtree->calcError(data_set, false, | ||
// use train data | ||
results // cv::noArray() | ||
); | ||
std::vector<cv::String> names; | ||
data_set->getNames(names); | ||
Mat flags = data_set->getVarSymbolFlags(); | ||
// Compute some statistics on our own: | ||
// | ||
{ | ||
cv::Mat expected_responses = data_set->getResponses(); | ||
int good = 0, bad = 0, total = 0; | ||
for (int i = 0; i < data_set->getNTrainSamples(); ++i) { | ||
float received = results.at<float>(i, 0); | ||
float expected = expected_responses.at<float>(i, 0); | ||
cv::String r_str = names[(int)received]; | ||
cv::String e_str = names[(int)expected]; | ||
cout << "Expected: " << e_str << ", got: " << r_str << endl; | ||
if (received == expected) | ||
good++; | ||
else | ||
bad++; | ||
total++; | ||
} | ||
cout << "Correct answers: " <<(float(good)/total) <<" % " << endl; | ||
cout << "Incorrect answers: " << (float(bad) / total) << "%" | ||
<< endl; | ||
} | ||
float test_performance = dtree->calcError(data_set, true, | ||
// use test data | ||
results // cv::noArray() | ||
); | ||
cout << "Performance on training data: " << train_performance << "%" << endl; | ||
cout << "Performance on test data: " <<test_performance <<" % " <<endl; | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Index of mushroom | ||
|
||
02 Dec 1996 193 Index | ||
25 Jun 1990 111577 expanded.Z | ||
26 Feb 1990 4167 agaricus-lepiota.names | ||
30 May 1989 853 README | ||
30 May 1989 373704 agaricus-lepiota.data |
Oops, something went wrong.