Skip to content

Commit

Permalink
Update decision tree plot code
Browse files Browse the repository at this point in the history
  • Loading branch information
Lindsey Berlin committed Aug 17, 2020
1 parent accc048 commit 8dacd50
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 37 deletions.
60 changes: 24 additions & 36 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@ In order to prepare data, train, evaluate, and visualize a decision tree, we wil
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.tree import export_graphviz
from sklearn.preprocessing import OneHotEncoder
from IPython.display import Image
from sklearn.tree import export_graphviz
from pydotplus import graph_from_dot_data
from sklearn import tree
```

The play tennis dataset is available in the repo as `'tennis.csv'`. For this step, we'll start by importing the csv file as a pandas DataFrame.
Expand Down Expand Up @@ -71,39 +69,39 @@ df.head()
</thead>
<tbody>
<tr>
<th>0</th>
<td>0</td>
<td>sunny</td>
<td>hot</td>
<td>high</td>
<td>False</td>
<td>no</td>
</tr>
<tr>
<th>1</th>
<td>1</td>
<td>sunny</td>
<td>hot</td>
<td>high</td>
<td>True</td>
<td>no</td>
</tr>
<tr>
<th>2</th>
<td>2</td>
<td>overcast</td>
<td>hot</td>
<td>high</td>
<td>False</td>
<td>yes</td>
</tr>
<tr>
<th>3</th>
<td>3</td>
<td>rainy</td>
<td>mild</td>
<td>high</td>
<td>False</td>
<td>yes</td>
</tr>
<tr>
<th>4</th>
<td>4</td>
<td>rainy</td>
<td>cool</td>
<td>normal</td>
Expand All @@ -122,8 +120,8 @@ Before we do anything we'll want to split our data into **_training_** and **_te


```python
X = df.loc[:, ['outlook', 'temp', 'humidity', 'windy']]
y = df.loc[:, 'play']
X = df[['outlook', 'temp', 'humidity', 'windy']]
y = df[['play']]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = 42)
```
Expand Down Expand Up @@ -181,7 +179,7 @@ ohe_df.head()
</thead>
<tbody>
<tr>
<th>0</th>
<td>0</td>
<td>0.0</td>
<td>0.0</td>
<td>1.0</td>
Expand All @@ -194,7 +192,7 @@ ohe_df.head()
<td>0.0</td>
</tr>
<tr>
<th>1</th>
<td>1</td>
<td>1.0</td>
<td>0.0</td>
<td>0.0</td>
Expand All @@ -207,7 +205,7 @@ ohe_df.head()
<td>0.0</td>
</tr>
<tr>
<th>2</th>
<td>2</td>
<td>0.0</td>
<td>0.0</td>
<td>1.0</td>
Expand All @@ -220,7 +218,7 @@ ohe_df.head()
<td>1.0</td>
</tr>
<tr>
<th>3</th>
<td>3</td>
<td>0.0</td>
<td>1.0</td>
<td>0.0</td>
Expand All @@ -233,7 +231,7 @@ ohe_df.head()
<td>1.0</td>
</tr>
<tr>
<th>4</th>
<td>4</td>
<td>0.0</td>
<td>1.0</td>
<td>0.0</td>
Expand Down Expand Up @@ -268,43 +266,33 @@ clf.fit(X_train_ohe, y_train)



DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=None,
max_features=None, max_leaf_nodes=None,
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='entropy',
max_depth=None, max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort=False,
min_weight_fraction_leaf=0.0, presort='deprecated',
random_state=None, splitter='best')



## Plot the decision tree

You can see what rules the tree learned by plotting this decision tree. To do this, you need to use additional packages such as `pytdotplus`.

> **Note:** If you are run into errors while generating the plot, you probably need to install `python-graphviz` in your machine using `conda install python-graphviz`.
You can see what rules the tree learned by plotting this decision tree, using matplotlib and sklearn's `plot_tree` function.


```python
# Create DOT data
dot_data = export_graphviz(clf, out_file=None,
feature_names=ohe_df.columns,
class_names=np.unique(y).astype('str'),
filled=True, rounded=True, special_characters=True)

# Draw graph
graph = graph_from_dot_data(dot_data)

# Show graph
Image(graph.create_png())
fig, axes = plt.subplots(nrows = 1,ncols = 1, figsize = (3,3), dpi=300)
tree.plot_tree(clf,
feature_names = ohe_df.columns,
class_names=np.unique(y).astype('str'),
filled = True)
plt.show()
```




![png](index_files/index_11_0.png)



## Evaluate the predictive performance

Now that we have a trained model, we can generate some predictions, and go on to see how accurate our predictions are. We can use a simple accuracy measure, AUC, a confusion matrix, or all of them. This step is performed in the exactly the same manner, so it doesn't matter which classifier you are dealing with.
Expand Down
2 changes: 1 addition & 1 deletion index.ipynb

Large diffs are not rendered by default.

Binary file modified index_files/index_11_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 8dacd50

Please sign in to comment.