Skip to content

Commit

Permalink
Merge pull request #27 from plming/type-hint
Browse files Browse the repository at this point in the history
add type hints
  • Loading branch information
mthh committed Nov 10, 2022
2 parents c4156c4 + c30ac18 commit 1951332
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 39 deletions.
81 changes: 42 additions & 39 deletions jenkspy/core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# -*- coding: utf-8 -*-
import numpy as np
from collections.abc import Iterable
from collections.abc import Iterable as IterableType
from typing import List, Dict, Union, Iterable, Sequence
from . import jenks


class JenksNaturalBreaks:
"""
A class that can be used to classify a sequence of numbers into groups (clusters) using Fisher-Jenks natural breaks.
"""
def __init__(self, n_classes=6):

def __init__(self, n_classes: int = 6) -> None:
"""
Parameters
----------
Expand All @@ -17,68 +19,68 @@ def __init__(self, n_classes=6):
"""
self.n_classes = n_classes

def __repr__(self):
def __repr__(self) -> str:
return f"JenksNaturalBreaks(n_classes={self.n_classes})"

def __str__(self):
def __str__(self) -> str:
return f"JenksNaturalBreaks(n_classes={self.n_classes})"

def fit(self, x):
def fit(self, x: Sequence[float]) -> None:
"""
Parameters
----------
x : array-like
The Iterable sequence of numbers (integer/float) to be classified.
The sequence of numbers (integer/float) to be classified.
"""
self.breaks_ = jenks_breaks(x, self.n_classes)
self.inner_breaks_ = self.breaks_[1:-1] # because inner_breaks is more
self.labels_ = self.predict(x)
self.groups_ = self.group(x)

def predict(self, x):
def predict(self, x: Union[float, Iterable[float]]) -> np.ndarray:
"""
Predicts the class of each element in x.
Parameters
----------
x : array-like
x : scalar or array-like
Returns
-------
list
numpy.array
"""
if np.size(x) == 1:
if not isinstance(x, IterableType):
return np.array(self.get_label_(x, idx=0))
else:
labels_ = []
for val in x:
label_ = self.get_label_(val, idx=0)
labels_.append(label_)
return np.array(labels_)
def group(self, x):

labels_ = []
for val in x:
label_ = self.get_label_(val, idx=0)
labels_.append(label_)
return np.array(labels_)

def group(self, x: Sequence[float]) -> List[np.ndarray]:
"""
Groups the elements in x into groups according to the classifier.
Parameters
----------
x : array-like
The Iterable sequence of numbers (integer/float) to be classified.
The sequence of numbers (integer/float) to be classified.
Returns
-------
list
list of numpy.array
The list of groups that contains the values of x.
"""
x = np.array(x)
groups_ = [x[x <= self.inner_breaks_[0]]]
arr = np.array(x)
groups_ = [arr[arr <= self.inner_breaks_[0]]]
for idx in range(len(self.inner_breaks_))[:-1]:
groups_.append(x[(x > self.inner_breaks_[idx]) * (x <= self.inner_breaks_[idx+1])])
groups_.append(x[x > self.inner_breaks_[-1]])
groups_.append(arr[(arr > self.inner_breaks_[idx])*(arr <= self.inner_breaks_[idx + 1])])
groups_.append(arr[arr > self.inner_breaks_[-1]])
return groups_
def goodness_of_variance_fit(self, x):

def goodness_of_variance_fit(self, x: Sequence[float]) -> float:
"""
Parameters
----------
Expand All @@ -89,17 +91,17 @@ def goodness_of_variance_fit(self, x):
float
The goodness of variance fit.
"""
x = np.array(x)
array_mean = np.mean(x)
sdam = sum([(value - array_mean)**2 for value in x])
arr = np.array(x)
array_mean = np.mean(arr)
sdam = sum([(value - array_mean) ** 2 for value in arr])
sdcm = 0
for group in self.groups_:
group_mean = np.mean(group)
sdcm += sum([(value - group_mean)**2 for value in group])
sdcm += sum([(value - group_mean) ** 2 for value in group])
gvf = (sdam - sdcm)/sdam
return gvf

def get_label_(self, val, idx=0):
def get_label_(self, val: float, idx: int = 0) -> int:
"""
Compute the group label of the given value.
Expand All @@ -117,15 +119,15 @@ def get_label_(self, val, idx=0):
if val <= self.inner_breaks_[idx]:
return idx
else:
idx = self.get_label_(val, idx+1)
idx = self.get_label_(val, idx + 1)
return idx
except:
return len(self.inner_breaks_)


def validate_input(values, n_classes):
# Check input so that we have an Iterable sequence of numbers
if not isinstance(values, Iterable) or isinstance(values, (str, bytes)):
def validate_input(values: Sequence[float], n_classes: int) -> int:
# Check input so that we have a sequence of numbers
if not isinstance(values, IterableType) or isinstance(values, (str, bytes)):
raise TypeError("A sequence of numbers is expected")

# The number of classes has to be an integer
Expand Down Expand Up @@ -162,15 +164,16 @@ def validate_input(values, n_classes):

return n_classes

def jenks_breaks(values, n_classes):

def jenks_breaks(values: Sequence[float], n_classes: int) -> List[float]:
"""
Compute natural breaks (Fisher-Jenks algorithm) on a sequence of `values`,
given `n_classes`, the desired number of classes.
Parameters
----------
values : array-like
The Iterable sequence of numbers (integer/float) to be used.
The sequence of numbers (integer/float) to be used.
n_classes : int
The desired number of classes. Must be less than or equal
to the length of `values` and greater than or equal to 1.
Expand All @@ -196,7 +199,7 @@ def jenks_breaks(values, n_classes):
return jenks._jenks_breaks(values, n_classes)


def _jenks_matrices(values, n_classes, testing_algo=False):
def _jenks_matrices(values: Sequence[float], n_classes: int, testing_algo: bool = False) -> Dict[str, np.ndarray]:
"""
Returns the intermediate matrices (lower_class_limits and variance combinations)
that are created when computing natural breaks (Fisher-Jenks algorithm) on a sequence of `values`,
Expand All @@ -207,7 +210,7 @@ def _jenks_matrices(values, n_classes, testing_algo=False):
Parameters
----------
values : array-like
The Iterable sequence of numbers (integer/float) to be used.
The sequence of numbers (integer/float) to be used.
n_classes : int
The desired number of classes. Must be less than or equal
to the length of `values` and greater than or equal to 1.
Expand Down
7 changes: 7 additions & 0 deletions tests/test_jenks.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,19 @@ def test_predict_multiple_values(self):
jnb.predict(150)

jnb.fit(self.data2)

# predicting on an iterable returns a numpy array
predicted = jnb.predict([150, 700])
for val_predict, val_true in zip(predicted, self.res7):
self.assertEqual(val_predict, val_true)
self.assertEqual(type(val_predict).__name__, type(val_true).__name__)

# also works with other kinds of iterables (e.g. a generator) as argument
predicted = jnb.predict((i for i in range(150, 701, 550)))
for val_predict, val_true in zip(predicted, self.res7):
self.assertEqual(val_predict, val_true)
self.assertEqual(type(val_predict).__name__, type(val_true).__name__)

def test_grouping(self):
"""
JenksNaturalBreaks.group groups new values according to the existing breaks
Expand Down

0 comments on commit 1951332

Please sign in to comment.