Skip to content

Commit

Permalink
Adding the Decision tree and the mushroom database
Browse files Browse the repository at this point in the history
  • Loading branch information
Prasanna committed Jun 11, 2017
1 parent a1630a0 commit 7375ccf
Show file tree
Hide file tree
Showing 7 changed files with 8,390 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ add_executable( example_17-01 example_17-01.cpp )
add_executable( example_18-01 example_18-01.cpp )
add_executable( example_20-01 example_20-01.cpp )
add_executable( example_20-02 example_20-02.cpp )
add_executable( example_21-01 example_21-01.cpp )
#...

target_link_libraries( example_02-01 ${OpenCV_LIBS} )
Expand Down Expand Up @@ -120,4 +121,5 @@ target_link_libraries( example_17-01 ${OpenCV_LIBS} )
target_link_libraries( example_18-01 ${OpenCV_LIBS} )
target_link_libraries( example_20-01 ${OpenCV_LIBS} )
target_link_libraries( example_20-02 ${OpenCV_LIBS} )
target_link_libraries( example_21-01 ${OpenCV_LIBS} )
#...
106 changes: 106 additions & 0 deletions example_21-01.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#include <opencv2/opencv.hpp>
#include <stdio.h>
#include <iostream>
using namespace std;
using namespace cv;
int main(int argc, char *argv[]) {
// If the caller gave a filename, great. Otherwise, use a default.
//
const char *csv_file_name = argc >= 2 ? argv[1] : "agaricus-lepiota.data";
cout << "OpenCV Version: " << CV_VERSION << endl;
// Read in the CSV file that we were given.
//
cv::Ptr<cv::ml::TrainData> data_set =
cv::ml::TrainData::loadFromCSV(csv_file_name,
// Input file name
0,
// Header lines (ignore this many)
0,
// Responses are (start) at thie column
1,
// Inputs start at this column
"cat[0-22]"
// All 23 columns are categorical
);
// Use defaults for delimeter (',') and missch ('?')
// Verify that we read in what we think.
//
int n_samples = data_set->getNSamples();
if (n_samples == 0) {
cerr << "Could not read file: " << csv_file_name << endl;
exit(-1);
} else {
cout << "Read " << n_samples << " samples from " << csv_file_name << endl;
}
// Split the data, so that 90% is train data
//
data_set->setTrainTestSplitRatio(0.90, false);
int n_train_samples = data_set->getNTrainSamples();
int n_test_samples = data_set->getNTestSamples();
cout << "Found " << n_train_samples << " Train Samples, and "
<< n_test_samples << " Test Samples" << endl;
// Create a DTrees classifier.
//
cv::Ptr<cv::ml::RTrees> dtree = cv::ml::RTrees::create();
// set parameters
//
// These are the parameters from the old mushrooms.cpp code
// Set up priors to penalize "poisonous" 10x as much as "edible"
//
float _priors[] = {1.0, 10.0};
cv::Mat priors(1, 2, CV_32F, _priors);
dtree->setMaxDepth(8);
dtree->setMinSampleCount(10);
dtree->setRegressionAccuracy(0.01f);
dtree->setUseSurrogates(false /* true */);
dtree->setMaxCategories(15);
dtree->setCVFolds(0 /*10*/); // nonzero causes core dump
dtree->setUse1SERule(true);
dtree->setTruncatePrunedTree(true);
// dtree->setPriors( priors );
dtree->setPriors(cv::Mat()); // ignore priors for now...
// Now train the model
// NB: we are only using the "train" part of the data set
//
dtree->train(data_set);
// Having successfully trained the data, we should be able
// to calculate the error on both the training data, as well
// as the test data that we held out.
//
cv::Mat results;
float train_performance = dtree->calcError(data_set, false,
// use train data
results // cv::noArray()
);
std::vector<cv::String> names;
data_set->getNames(names);
Mat flags = data_set->getVarSymbolFlags();
// Compute some statistics on our own:
//
{
cv::Mat expected_responses = data_set->getResponses();
int good = 0, bad = 0, total = 0;
for (int i = 0; i < data_set->getNTrainSamples(); ++i) {
float received = results.at<float>(i, 0);
float expected = expected_responses.at<float>(i, 0);
cv::String r_str = names[(int)received];
cv::String e_str = names[(int)expected];
cout << "Expected: " << e_str << ", got: " << r_str << endl;
if (received == expected)
good++;
else
bad++;
total++;
}
cout << "Correct answers: " <<(float(good)/total) <<" % " << endl;
cout << "Incorrect answers: " << (float(bad) / total) << "%"
<< endl;
}
float test_performance = dtree->calcError(data_set, true,
// use test data
results // cv::noArray()
);
cout << "Performance on training data: " << train_performance << "%" << endl;
cout << "Performance on test data: " <<test_performance <<" % " <<endl;
return 0;
}
7 changes: 7 additions & 0 deletions mushroom/Index
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Index of mushroom

02 Dec 1996 193 Index
25 Jun 1990 111577 expanded.Z
26 Feb 1990 4167 agaricus-lepiota.names
30 May 1989 853 README
30 May 1989 373704 agaricus-lepiota.data
Loading

0 comments on commit 7375ccf

Please sign in to comment.