diff --git a/jenkspy/core.py b/jenkspy/core.py index 2a962d6..4e37736 100644 --- a/jenkspy/core.py +++ b/jenkspy/core.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import numpy as np -from collections.abc import Iterable +from collections.abc import Iterable as IterableType +from typing import List, Dict, Union, Iterable, Sequence from . import jenks @@ -8,7 +9,8 @@ class JenksNaturalBreaks: """ A class that can be used to classify a sequence of numbers into groups (clusters) using Fisher-Jenks natural breaks. """ - def __init__(self, n_classes=6): + + def __init__(self, n_classes: int = 6) -> None: """ Parameters ---------- @@ -17,18 +19,18 @@ def __init__(self, n_classes=6): """ self.n_classes = n_classes - def __repr__(self): + def __repr__(self) -> str: return f"JenksNaturalBreaks(n_classes={self.n_classes})" - def __str__(self): + def __str__(self) -> str: return f"JenksNaturalBreaks(n_classes={self.n_classes})" - def fit(self, x): + def fit(self, x: Sequence[float]) -> None: """ Parameters ---------- x : array-like - The Iterable sequence of numbers (integer/float) to be classified. + The sequence of numbers (integer/float) to be classified. """ self.breaks_ = jenks_breaks(x, self.n_classes) @@ -36,49 +38,49 @@ def fit(self, x): self.labels_ = self.predict(x) self.groups_ = self.group(x) - def predict(self, x): + def predict(self, x: Union[float, Iterable[float]]) -> np.ndarray: """ Predicts the class of each element in x. Parameters ---------- - x : array-like + x : scalar or array-like Returns ------- - list + numpy.array """ - if np.size(x) == 1: + if not isinstance(x, IterableType): return np.array(self.get_label_(x, idx=0)) - else: - labels_ = [] - for val in x: - label_ = self.get_label_(val, idx=0) - labels_.append(label_) - return np.array(labels_) - - def group(self, x): + + labels_ = [] + for val in x: + label_ = self.get_label_(val, idx=0) + labels_.append(label_) + return np.array(labels_) + + def group(self, x: Sequence[float]) -> List[np.ndarray]: """ Groups the elements in x into groups according to the classifier. Parameters ---------- x : array-like - The Iterable sequence of numbers (integer/float) to be classified. + The sequence of numbers (integer/float) to be classified. Returns ------- - list + list of numpy.array The list of groups that contains the values of x. """ - x = np.array(x) - groups_ = [x[x <= self.inner_breaks_[0]]] + arr = np.array(x) + groups_ = [arr[arr <= self.inner_breaks_[0]]] for idx in range(len(self.inner_breaks_))[:-1]: - groups_.append(x[(x > self.inner_breaks_[idx]) * (x <= self.inner_breaks_[idx+1])]) - groups_.append(x[x > self.inner_breaks_[-1]]) + groups_.append(arr[(arr > self.inner_breaks_[idx])*(arr <= self.inner_breaks_[idx + 1])]) + groups_.append(arr[arr > self.inner_breaks_[-1]]) return groups_ - - def goodness_of_variance_fit(self, x): + + def goodness_of_variance_fit(self, x: Sequence[float]) -> float: """ Parameters ---------- @@ -89,17 +91,17 @@ def goodness_of_variance_fit(self, x): float The goodness of variance fit. """ - x = np.array(x) - array_mean = np.mean(x) - sdam = sum([(value - array_mean)**2 for value in x]) + arr = np.array(x) + array_mean = np.mean(arr) + sdam = sum([(value - array_mean) ** 2 for value in arr]) sdcm = 0 for group in self.groups_: group_mean = np.mean(group) - sdcm += sum([(value - group_mean)**2 for value in group]) + sdcm += sum([(value - group_mean) ** 2 for value in group]) gvf = (sdam - sdcm)/sdam return gvf - def get_label_(self, val, idx=0): + def get_label_(self, val: float, idx: int = 0) -> int: """ Compute the group label of the given value. @@ -117,15 +119,15 @@ def get_label_(self, val, idx=0): if val <= self.inner_breaks_[idx]: return idx else: - idx = self.get_label_(val, idx+1) + idx = self.get_label_(val, idx + 1) return idx except: return len(self.inner_breaks_) -def validate_input(values, n_classes): - # Check input so that we have an Iterable sequence of numbers - if not isinstance(values, Iterable) or isinstance(values, (str, bytes)): +def validate_input(values: Sequence[float], n_classes: int) -> int: + # Check input so that we have a sequence of numbers + if not isinstance(values, IterableType) or isinstance(values, (str, bytes)): raise TypeError("A sequence of numbers is expected") # Number of classes have to be an integer @@ -162,7 +164,8 @@ def validate_input(values, n_classes): return n_classes -def jenks_breaks(values, n_classes): + +def jenks_breaks(values: Sequence[float], n_classes: int) -> List[float]: """ Compute natural breaks (Fisher-Jenks algorithm) on a sequence of `values`, given `n_classes`, the number of desired class. @@ -170,7 +173,7 @@ def jenks_breaks(values, n_classes): Parameters ---------- values : array-like - The Iterable sequence of numbers (integer/float) to be used. + The sequence of numbers (integer/float) to be used. n_classes : int The desired number of class. Have to be lesser than or equal to the length of `values` and greater than or equal to 1. @@ -196,7 +199,7 @@ def jenks_breaks(values, n_classes): return jenks._jenks_breaks(values, n_classes) -def _jenks_matrices(values, n_classes, testing_algo=False): +def _jenks_matrices(values: Sequence[float], n_classes: int, testing_algo: bool = False) -> Dict[str, np.ndarray]: """ Returns the intermediate matrices (lower_class_limits and variance combinations) that are created when computing natural breaks (Fisher-Jenks algorithm) on a sequence of `values`, @@ -207,7 +210,7 @@ def _jenks_matrices(values, n_classes, testing_algo=False): Parameters ---------- values : array-like - The Iterable sequence of numbers (integer/float) to be used. + The sequence of numbers (integer/float) to be used. n_classes : int The desired number of class. Have to be lesser than or equal to the length of `values` and greater than or equal to 1. diff --git a/tests/test_jenks.py b/tests/test_jenks.py index ce7d697..b4eed76 100755 --- a/tests/test_jenks.py +++ b/tests/test_jenks.py @@ -247,12 +247,19 @@ def test_predict_multiple_values(self): jnb.predict(150) jnb.fit(self.data2) + # predict iterable return numpy array predicted = jnb.predict([150, 700]) for val_predict, val_true in zip(predicted, self.res7): self.assertEqual(val_predict, val_true) self.assertEqual(type(val_predict).__name__, type(val_true).__name__) + # also works with other iterable as argument + predicted = jnb.predict((i for i in range(150, 701, 550))) + for val_predict, val_true in zip(predicted, self.res7): + self.assertEqual(val_predict, val_true) + self.assertEqual(type(val_predict).__name__, type(val_true).__name__) + def test_grouping(self): """ JenksNaturalBreaks.groups groups new values according to the existing breaks