diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index 6224dee324a57..82537e56f3738 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ b/sklearn/ensemble/_gradient_boosting.pyx @@ -10,8 +10,8 @@ from scipy.sparse import issparse from sklearn.utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint8_t # Note: _tree uses cimport numpy, cnp.import_array, so we need to include # numpy headers in the build configuration of this extension -from sklearn.tree._tree cimport Node from sklearn.tree._tree cimport Tree +from sklearn.tree._utils cimport Node from sklearn.tree._utils cimport safe_realloc diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index dc657365aaec1..85fe623166d44 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -125,6 +125,10 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): "min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")], "ccp_alpha": [Interval(Real, 0.0, None, closed="left")], "monotonic_cst": ["array-like", None], + "categorical_features": [ + "array-like", + None, + ], } @abstractmethod @@ -144,6 +148,7 @@ def __init__( class_weight=None, ccp_alpha=0.0, monotonic_cst=None, + categorical_features=None, ): self.criterion = criterion self.splitter = splitter @@ -158,6 +163,7 @@ def __init__( self.class_weight = class_weight self.ccp_alpha = ccp_alpha self.monotonic_cst = monotonic_cst + self.categorical_features = categorical_features def get_depth(self): """Return the depth of the decision tree. @@ -259,13 +265,18 @@ def _fit( missing_values_in_feature_mask = ( self._compute_missing_values_in_feature_mask(X) ) - if issparse(X): + is_sparse_X = issparse(X) + if is_sparse_X: X.sort_indices() if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: raise ValueError( "No support for np.int64 index based sparse matrices" ) + if is_sparse_X and self.categorical_features is not None: + raise NotImplementedError( + "Categorical features not supported with sparse inputs" + ) if self.criterion == "poisson": if np.any(y < 0): @@ -431,6 +442,10 @@ def _fit( # *positive class*, all signs must be flipped. 
monotonic_cst *= -1 + self.is_categorical_, n_categories_in_feature = ( + self._check_categorical_features(X, monotonic_cst) + ) + if not isinstance(self.splitter, Splitter): splitter = SPLITTERS[self.splitter]( criterion, @@ -442,13 +457,19 @@ def _fit( ) if is_classifier(self): - self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_) + self.tree_ = Tree( + self.n_features_in_, + self.n_classes_, + self.n_outputs_, + self.is_categorical_, + ) else: self.tree_ = Tree( self.n_features_in_, # TODO: tree shouldn't need this in this case np.array([1] * self.n_outputs_, dtype=np.intp), self.n_outputs_, + self.is_categorical_, ) # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise @@ -472,7 +493,14 @@ def _fit( self.min_impurity_decrease, ) - builder.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask) + builder.build( + self.tree_, + X, + y, + sample_weight, + missing_values_in_feature_mask, + n_categories_in_feature, + ) if self.n_outputs_ == 1 and is_classifier(self): self.n_classes_ = self.n_classes_[0] @@ -482,6 +510,78 @@ def _fit( return self + def _check_categorical_features(self, X, monotonic_cst): + """Check and validate categorical features in X + + Parameters + ---------- + X : {array-like} of shape (n_samples, n_features) + Input data (after `validate_data` was called) + + Return + ------ + is_categorical : ndarray of shape (n_features,) or None, dtype=bool + Indicates whether a feature is categorical. If no feature is + categorical, this is None. + n_categories_in_feature: TODO + """ + n_features = X.shape[1] + categorical_features = np.asarray(self.categorical_features) + + if self.categorical_features is None or categorical_features.size == 0: + is_categorical = np.zeros(n_features, dtype=bool) + elif categorical_features.dtype.kind not in ("i", "b"): + raise ValueError( + "categorical_features must be an array-like of bool or int, " + f"got: {categorical_features.dtype.name}." + ) + elif categorical_features.dtype.kind == "i": + # check for categorical features as indices + if ( + np.max(categorical_features) >= n_features + or np.min(categorical_features) < 0 + ): + raise ValueError( + "categorical_features set as integer " + "indices must be in [0, n_features - 1]" + ) + is_categorical = np.zeros(n_features, dtype=bool) + is_categorical[categorical_features] = True + else: + if categorical_features.shape[0] != n_features: + raise ValueError( + "categorical_features set as a boolean mask " + "must have shape (n_features,), got: " + f"{categorical_features.shape}" + ) + is_categorical = categorical_features + + n_categories_in_feature = np.full(self.n_features_in_, -1, dtype=np.intp) + MAX_NC = 64 # TODO import from somewhere + base_msg = ( + f"Values for categorical features should be integers in [0, {MAX_NC - 1}]." + ) + for idx in np.where(is_categorical)[0]: + if np.isnan(X[:, idx]).any(): + raise ValueError( + "Missing values are not supported in categorical features" + ) + if not np.allclose(X[:, idx].astype(np.intp), X[:, idx]): + raise ValueError(f"{base_msg} Found non-integer values.") + if X[:, idx].min() < 0: + raise ValueError(f"{base_msg} Found negative values.") + X_idx_max = X[:, idx].max() + if X_idx_max >= MAX_NC: + raise ValueError(f"{base_msg} Found {X_idx_max}.") + n_categories_in_feature[idx] = X_idx_max + 1 + if monotonic_cst is not None and monotonic_cst[idx] != 0: + raise ValueError( + "A categorical feature cannot have a non-null monotonic" + " constraint. 
" + ) + + return is_categorical, n_categories_in_feature + def _validate_X_predict(self, X, check_input): """Validate the training data on predict (probabilities).""" if check_input: @@ -620,13 +720,16 @@ def _prune_tree(self): # build pruned tree if is_classifier(self): n_classes = np.atleast_1d(self.n_classes_) - pruned_tree = Tree(self.n_features_in_, n_classes, self.n_outputs_) + pruned_tree = Tree( + self.n_features_in_, n_classes, self.n_outputs_, self.is_categorical_ + ) else: pruned_tree = Tree( self.n_features_in_, # TODO: the tree shouldn't need this param np.array([1] * self.n_outputs_, dtype=np.intp), self.n_outputs_, + self.is_categorical_, ) _build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha) @@ -976,6 +1079,7 @@ def __init__( class_weight=None, ccp_alpha=0.0, monotonic_cst=None, + categorical_features=None, ): super().__init__( criterion=criterion, @@ -991,6 +1095,7 @@ def __init__( min_impurity_decrease=min_impurity_decrease, monotonic_cst=monotonic_cst, ccp_alpha=ccp_alpha, + categorical_features=categorical_features, ) @_fit_context(prefer_skip_nested_validation=True) @@ -1353,6 +1458,7 @@ def __init__( min_impurity_decrease=0.0, ccp_alpha=0.0, monotonic_cst=None, + categorical_features=None, ): super().__init__( criterion=criterion, @@ -1367,6 +1473,7 @@ def __init__( min_impurity_decrease=min_impurity_decrease, ccp_alpha=ccp_alpha, monotonic_cst=monotonic_cst, + categorical_features=categorical_features, ) @_fit_context(prefer_skip_nested_validation=True) diff --git a/sklearn/tree/_partitioner.pxd b/sklearn/tree/_partitioner.pxd index a8ea709ae787c..eda3721e0efc5 100644 --- a/sklearn/tree/_partitioner.pxd +++ b/sklearn/tree/_partitioner.pxd @@ -3,10 +3,12 @@ # See _partitioner.pyx for details. + from sklearn.utils._typedefs cimport ( - float32_t, float64_t, int8_t, int32_t, intp_t, uint8_t, uint32_t + float32_t, float64_t, int8_t, int32_t, intp_t, uint8_t, uint32_t, uint64_t ) from sklearn.tree._splitter cimport SplitRecord +from sklearn.tree._utils cimport SplitValue # Mitigate precision differences between 32 bit and 64 bit @@ -68,15 +70,25 @@ cdef class DensePartitioner: Note that this partitioner is agnostic to the splitting strategy (best vs. random). 
""" cdef const float32_t[:, :] X + cdef const float64_t[:, :] y + cdef const float64_t[::1] sample_weight cdef intp_t[::1] samples cdef float32_t[::1] feature_values cdef intp_t start cdef intp_t end cdef intp_t n_missing cdef const uint8_t[::1] missing_values_in_feature_mask + cdef const intp_t[::1] n_categories_in_feature cdef bint missing_on_the_left + cdef intp_t n_categories + + cdef intp_t[::1] counts + cdef float64_t[::1] weighted_counts + cdef float64_t[::1] means + cdef intp_t[::1] sorted_cat + cdef intp_t[::1] offsets - cdef void sort_samples_and_feature_values( + cdef bint sort_samples_and_feature_values( self, intp_t current_feature ) noexcept nogil cdef void shift_missing_to_the_left(self) noexcept nogil @@ -96,6 +108,9 @@ cdef class DensePartitioner: intp_t* p_prev, intp_t* p ) noexcept nogil + cdef inline SplitValue pos_to_threshold( + self, intp_t p_prev, intp_t p + ) noexcept nogil cdef intp_t partition_samples( self, float64_t current_threshold, @@ -104,10 +119,12 @@ cdef class DensePartitioner: cdef void partition_samples_final( self, intp_t best_pos, - float64_t best_threshold, + SplitValue split_value, intp_t best_feature, bint best_missing_go_to_left, ) noexcept nogil + cdef void _breiman_sort_categories(self, intp_t nc) noexcept nogil + cdef inline uint64_t _split_pos_to_bitset(self, intp_t p, intp_t nc) noexcept nogil cdef class SparsePartitioner: @@ -132,8 +149,9 @@ cdef class SparsePartitioner: cdef intp_t end cdef intp_t n_missing cdef const uint8_t[::1] missing_values_in_feature_mask + cdef intp_t n_categories - cdef void sort_samples_and_feature_values( + cdef bint sort_samples_and_feature_values( self, intp_t current_feature ) noexcept nogil cdef void shift_missing_to_the_left(self) noexcept nogil @@ -153,6 +171,9 @@ cdef class SparsePartitioner: intp_t* p_prev, intp_t* p ) noexcept nogil + cdef inline SplitValue pos_to_threshold( + self, intp_t p_prev, intp_t p + ) noexcept nogil cdef intp_t partition_samples( self, float64_t current_threshold, @@ -161,7 +182,7 @@ cdef class SparsePartitioner: cdef void partition_samples_final( self, intp_t best_pos, - float64_t best_threshold, + SplitValue split_value, intp_t best_feature, bint best_missing_go_to_left, ) noexcept nogil diff --git a/sklearn/tree/_partitioner.pyx b/sklearn/tree/_partitioner.pyx index 467f724ee1dce..9ef254dea41e8 100644 --- a/sklearn/tree/_partitioner.pyx +++ b/sklearn/tree/_partitioner.pyx @@ -13,7 +13,7 @@ and sparse data stored in a Compressed Sparse Column (CSC) format. 
from cython cimport final from libc.math cimport isnan, log2 from libc.stdlib cimport qsort -from libc.string cimport memcpy +from libc.string cimport memcpy, memset from ._utils cimport swap_array_slices @@ -25,9 +25,12 @@ from scipy.sparse import issparse # in SparsePartitioner cdef float32_t EXTRACT_NNZ_SWITCH = 0.1 +cdef float64_t INFINITY = np.inf # Allow for 32 bit float comparisons cdef float32_t INFINITY_32t = np.inf +cdef intp_t MAX_N_CAT = 64 + @final cdef class DensePartitioner: @@ -38,15 +41,29 @@ cdef class DensePartitioner: def __init__( self, const float32_t[:, :] X, + const float64_t[:, :] y, + const float64_t[::1] sample_weight, intp_t[::1] samples, float32_t[::1] feature_values, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ): self.X = X + self.y = y + self.sample_weight = sample_weight self.samples = samples self.feature_values = feature_values self.missing_values_in_feature_mask = missing_values_in_feature_mask + self.n_categories_in_feature = n_categories_in_feature self.missing_on_the_left = False + self.n_categories = 0 + + # for breiman shortcut: + self.counts = np.empty(MAX_N_CAT, dtype=np.intp) + self.weighted_counts = np.empty(MAX_N_CAT, dtype=np.float64) + self.means = np.empty(MAX_N_CAT, dtype=np.float64) + self.sorted_cat = np.empty(MAX_N_CAT, dtype=np.intp) + self.offsets = np.empty(MAX_N_CAT, dtype=np.intp) cdef inline void init_node_split(self, intp_t start, intp_t end) noexcept nogil: """Initialize splitter at the beginning of node_split.""" @@ -54,7 +71,7 @@ cdef class DensePartitioner: self.end = end self.n_missing = 0 - cdef inline void sort_samples_and_feature_values( + cdef inline bint sort_samples_and_feature_values( self, intp_t current_feature ) noexcept nogil: """Simultaneously sort based on the feature_values. @@ -98,9 +115,91 @@ cdef class DensePartitioner: for i in range(self.start, self.end): feature_values[i] = X[samples[i], current_feature] - sort(&feature_values[self.start], &samples[self.start], self.end - self.start - n_missing) self.missing_on_the_left = False self.n_missing = n_missing + self.n_categories = self.n_categories_in_feature[current_feature] + if n_missing == self.end - self.start: + return True + if self.n_categories <= 0: + # not a categorical feature + sort(&feature_values[self.start], &samples[self.start], self.end - self.start - n_missing) + if n_missing > 0: + return False + return feature_values[self.end - n_missing - 1] <= feature_values[self.start] + FEATURE_THRESHOLD + else: + self._breiman_sort_categories(self.n_categories) + return feature_values[self.start] == feature_values[self.end - 1] + + cdef void _breiman_sort_categories(self, intp_t nc) noexcept nogil: + """ + Order self.sorted_cat by ascending average target value + and order self.features_values & self.samples such that + - self.features_values is ordered according to the order of sorted_cat + - the relation `self.features_values[p] = self.X[self.samples[p], f]` is + preserved + + E.g. sorted_cat is [2 0 1] + features_values is [2 2 2 0 0 1 1 1 1] + + This ordering ensures the optimal split will be among the candidate splits + evaluated by the splitter (this is called the Brieman shortcut). 
+ + Time complexity: O(n + nc log nc) + """ + cdef: + intp_t* counts = &self.counts[0] + float64_t* weighted_counts = &self.weighted_counts[0] + float64_t* means = &self.means[0] + intp_t* sorted_cat = &self.sorted_cat[0] + intp_t* offsets = &self.offsets[0] + float32_t* feature_values = &self.feature_values[0] + intp_t* samples = &self.samples[0] + intp_t c, r, p, new_p + float64_t w = 1. + + memset(means, 0, nc * sizeof(float64_t)) + memset(counts, 0, nc * sizeof(intp_t)) + memset(weighted_counts, 0, nc * sizeof(float64_t)) + + # compute counts, weighted_counts and means + for p in range(self.start, self.end): + c = feature_values[p] + counts[c] += 1 + if self.sample_weight is not None: + w = self.sample_weight[samples[p]] + means[c] += w * self.y[samples[p], 0] + self.weighted_counts[c] += w + + for c in range(nc): + if weighted_counts[c] > 0: + means[c] /= weighted_counts[c] + + # sorted_cat[i] = i-th categories sorted by ascending means + for c in range(nc): + sorted_cat[c] = c + sort(means, sorted_cat, nc) + + # build offsets such that: + # offsets[c] = sum( counts[x] for all x s.t. rank(x) <= rank(c) ) - 1 + cdef intp_t offset = 0 + for r in range(nc): + c = sorted_cat[r] + offset += counts[c] + offsets[c] = offset - 1 + + # sort feature_values & samples in-place such that + # they are ordered by the mean of the category + # while ensuring samples of the same categories are contiguous + p = self.start + while p < self.end: + c = feature_values[p] + new_p = offsets[c] + if new_p > p: + swap(feature_values, samples, p, new_p) + # swap preserves invariant: feature[p] = X[samples[p], f] + offsets[c] -= 1 + else: + p += 1 cdef void shift_missing_to_the_left(self) noexcept nogil: """ @@ -165,12 +264,20 @@ cdef class DensePartitioner: cdef intp_t end_non_missing = ( self.end if self.missing_on_the_left else self.end - self.n_missing) + cdef float32_t c if p[0] == end_non_missing and not self.missing_on_the_left: # skip the missing values up to the end # (which will end the for loop in the best split function) p[0] = self.end p_prev[0] = self.end + elif self.n_categories > 0: + c = self.feature_values[p[0]] + p[0] += 1 + while p[0] < end_non_missing and self.feature_values[p[0]] == c: + p[0] += 1 + + # p_prev is unused in this case else: if self.missing_on_the_left and p[0] == self.start: # skip the missing values up to the first non-missing value: @@ -183,6 +290,33 @@ cdef class DensePartitioner: p[0] += 1 p_prev[0] = p[0] - 1 + cdef inline SplitValue pos_to_threshold( + self, intp_t p_prev, intp_t p + ) noexcept nogil: + cdef SplitValue split + cdef intp_t end_non_missing = ( + self.end if self.missing_on_the_left + else self.end - self.n_missing) + + if self.n_categories > 0: + split.cat_split = self._split_pos_to_bitset(p, self.n_categories) + return split + + if p == end_non_missing and not self.missing_on_the_left: + # split with the right node being only the missing values + split.threshold = INFINITY + return split + + # split between two non-missing values + # sum of halves is used to avoid infinite value + split.threshold = ( + self.feature_values[p_prev] / 2.0 + self.feature_values[p] / 2.0 + ) + if split.threshold == INFINITY or split.threshold == -INFINITY: + split.threshold = self.feature_values[p_prev] + + return split + cdef inline intp_t partition_samples( self, float64_t current_threshold, @@ -212,7 +346,7 @@ cdef class DensePartitioner: cdef inline void partition_samples_final( self, intp_t best_pos, - float64_t best_threshold, + SplitValue split_value, intp_t 
best_feature, bint best_missing_go_to_left ) noexcept nogil: @@ -227,20 +361,28 @@ cdef class DensePartitioner: intp_t partition_end = self.end intp_t* samples = &self.samples[0] float32_t current_value - bint go_to_left + bint is_cat = self.n_categories_in_feature[best_feature] > 0 while p < partition_end: current_value = self.X[samples[p], best_feature] - go_to_left = ( - best_missing_go_to_left if isnan(current_value) - else current_value <= best_threshold - ) - if go_to_left: + if goes_left(split_value, best_missing_go_to_left, is_cat, current_value): p += 1 else: partition_end -= 1 samples[p], samples[partition_end] = samples[partition_end], samples[p] + cdef inline uint64_t _split_pos_to_bitset(self, intp_t p, intp_t nc) noexcept nogil: + cdef uint64_t bitset = 0 + cdef intp_t r, c + cdef intp_t offset = 0 + for r in range(nc): + c = self.sorted_cat[r] + bitset |= 1 << c + offset += self.counts[c] + if offset >= p: + break + return bitset + @final cdef class SparsePartitioner: @@ -279,6 +421,7 @@ cdef class SparsePartitioner: self.index_to_samples[samples[p]] = p self.missing_values_in_feature_mask = missing_values_in_feature_mask + self.n_categories = 0 cdef inline void init_node_split(self, intp_t start, intp_t end) noexcept nogil: """Initialize splitter at the beginning of node_split.""" @@ -287,7 +430,7 @@ cdef class SparsePartitioner: self.is_samples_sorted = 0 self.n_missing = 0 - cdef inline void sort_samples_and_feature_values( + cdef inline bint sort_samples_and_feature_values( self, intp_t current_feature ) noexcept nogil: @@ -326,6 +469,8 @@ cdef class SparsePartitioner: # number of missing values for current_feature self.n_missing = 0 + return feature_values[self.end - 1] <= feature_values[self.start] + FEATURE_THRESHOLD + cdef void shift_missing_to_the_left(self) noexcept nogil: pass # missing values not support for sparse @@ -392,6 +537,21 @@ cdef class SparsePartitioner: p_prev[0] = p[0] p[0] = p_next + cdef inline SplitValue pos_to_threshold( + self, intp_t p_prev, intp_t p + ) noexcept nogil: + + cdef SplitValue split + # split between two non-missing values + # sum of halves is used to avoid infinite value + split.threshold = ( + self.feature_values[p_prev] / 2.0 + self.feature_values[p] / 2.0 + ) + if split.threshold == INFINITY or split.threshold == -INFINITY: + split.threshold = self.feature_values[p_prev] + + return split + cdef inline intp_t partition_samples( self, float64_t current_threshold, @@ -403,13 +563,13 @@ cdef class SparsePartitioner: cdef inline void partition_samples_final( self, intp_t best_pos, - float64_t best_threshold, + SplitValue split_value, intp_t best_feature, bint missing_go_to_left ) noexcept nogil: """Partition samples for X at the best_threshold and best_feature.""" self.extract_nnz(best_feature) - self._partition(best_threshold, best_pos) + self._partition(split_value.threshold, best_pos) cdef inline intp_t _partition(self, float64_t threshold, intp_t zero_pos) noexcept nogil: """Partition samples[start:end] based on threshold.""" @@ -499,6 +659,17 @@ cdef class SparsePartitioner: &self.end_negative, &self.start_positive) +cdef inline bint goes_left( + SplitValue split_value, bint missing_go_to_left, bint is_categorical, float32_t value +) noexcept nogil: + if isnan(value): + return missing_go_to_left + elif is_categorical: + return split_value.cat_split & (1 << ( value)) + else: + return value <= split_value.threshold + + cdef int compare_SIZE_t(const void* a, const void* b) noexcept nogil: """Comparison function for sort. 
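Editor's note on the two helpers just above (`_split_pos_to_bitset` and `goes_left`): once `_breiman_sort_categories` has grouped the node's samples by category in ascending-mean order, a split position simply selects a prefix of that ordering, and the categories in the prefix are recorded as set bits of a 64-bit integer. A minimal pure-Python sketch of that encoding follows; it is illustrative only (the names mirror the diff, but this is not the Cython code, and `p` here means "number of samples left of the split").

```python
# Illustrative pure-Python model of the 64-bit category bitset used by the
# dense partitioner (a sketch, not the Cython implementation above).
import math

MAX_N_CAT = 64  # categories are encoded as bits of a uint64


def split_pos_to_bitset(sorted_cat, counts, p):
    """Categories whose samples lie strictly left of position p, as a bitset."""
    bitset, offset = 0, 0
    for c in sorted_cat:      # categories in ascending-mean order
        bitset |= 1 << c      # this category goes to the left child
        offset += counts[c]
        if offset >= p:
            break
    return bitset


def goes_left(value, threshold, cat_split, missing_go_to_left, is_categorical):
    """Route one feature value, mirroring the goes_left helper in the diff."""
    if math.isnan(value):
        return missing_go_to_left
    if is_categorical:
        return bool(cat_split & (1 << int(value)))
    return value <= threshold


# Example: categories sorted by mean target as [2, 0, 1] with counts
# {0: 2, 1: 4, 2: 3}; a split after the first 3 samples puts only category 2
# in the left child.
bitset = split_pos_to_bitset([2, 0, 1], {0: 2, 1: 4, 2: 3}, p=3)
assert bitset == 0b100
assert goes_left(2.0, None, bitset, False, True)        # category 2 -> left
assert not goes_left(0.0, None, bitset, False, True)    # category 0 -> right
```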
@@ -656,26 +827,30 @@ def _py_sort(float32_t[::1] feature_values, intp_t[::1] samples, intp_t n): sort(&feature_values[0], &samples[0], n) +ctypedef fused floating_t: + float32_t + float64_t + # Sort n-element arrays pointed to by feature_values and samples, simultaneously, # by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). -cdef inline void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: +cdef inline void sort(floating_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: if n == 0: return cdef intp_t maxd = 2 * log2(n) introsort(feature_values, samples, n, maxd) -cdef inline void swap(float32_t* feature_values, intp_t* samples, +cdef inline void swap(floating_t* feature_values, intp_t* samples, intp_t i, intp_t j) noexcept nogil: # Helper for sort feature_values[i], feature_values[j] = feature_values[j], feature_values[i] samples[i], samples[j] = samples[j], samples[i] -cdef inline float32_t median3(float32_t* feature_values, intp_t n) noexcept nogil: +cdef inline floating_t median3(floating_t* feature_values, intp_t n) noexcept nogil: # Median of three pivot selection, after Bentley and McIlroy (1993). # Engineering a sort function. SP&E. Requires 8/3 comparisons on average. - cdef float32_t a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] + cdef floating_t a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1] if a < b: if b < c: return b @@ -694,9 +869,9 @@ cdef inline float32_t median3(float32_t* feature_values, intp_t n) noexcept nogi # Introsort with median of 3 pivot selection and 3-way partition function # (robust to repeated elements, e.g. lots of zero features). -cdef void introsort(float32_t* feature_values, intp_t *samples, +cdef void introsort(floating_t* feature_values, intp_t *samples, intp_t n, intp_t maxd) noexcept nogil: - cdef float32_t pivot + cdef floating_t pivot cdef intp_t i, l, r while n > 1: @@ -727,7 +902,7 @@ cdef void introsort(float32_t* feature_values, intp_t *samples, n -= r -cdef inline void sift_down(float32_t* feature_values, intp_t* samples, +cdef inline void sift_down(floating_t* feature_values, intp_t* samples, intp_t start, intp_t end) noexcept nogil: # Restore heap order in feature_values[start:end] by moving the max element to start. cdef intp_t child, maxind, root @@ -750,7 +925,7 @@ cdef inline void sift_down(float32_t* feature_values, intp_t* samples, root = maxind -cdef void heapsort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: +cdef void heapsort(floating_t* feature_values, intp_t* samples, intp_t n) noexcept nogil: cdef intp_t start, end # heapify diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index c0b264f188b5d..574738f69d6c8 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -8,6 +8,7 @@ from sklearn.utils._typedefs cimport ( ) from sklearn.tree._criterion cimport Criterion from sklearn.tree._tree cimport ParentInfo +from sklearn.tree._utils cimport SplitValue cdef struct SplitRecord: @@ -16,7 +17,7 @@ cdef struct SplitRecord: intp_t pos # Split samples array at the given position, # # i.e. count of samples below threshold for feature. # # pos is >= end if the node is a leaf. - float64_t threshold # Threshold to split at. + SplitValue value # Threshold/Bitset to split at. float64_t improvement # Impurity improvement given parent node. float64_t impurity_left # Impurity of the left split. float64_t impurity_right # Impurity of the right split. 
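Editor's note on why the ordering computed by `_breiman_sort_categories` is enough: it relies on the classic Breiman/Fisher result that, for squared-error splits (and binary-classification impurities such as gini), the best binary partition of the categories is always a prefix of the categories sorted by mean target, so only `nc - 1` candidates need to be scanned instead of `2**(nc - 1) - 1` subsets. A small NumPy check of that equivalence (a sketch under the squared-error assumption, independent of the Cython internals):

```python
# Breiman shortcut sanity check: brute-force subset search and the
# prefix-of-mean-sorted-categories search find the same optimal split loss.
from itertools import combinations

import numpy as np


def sse(v):
    """Sum of squared errors around the mean (0 for an empty child)."""
    return ((v - v.mean()) ** 2).sum() if v.size else 0.0


rng = np.random.default_rng(0)
x = rng.integers(0, 5, size=200)        # categorical feature with 5 categories
y = rng.normal(loc=x % 3, scale=0.5)    # target depends on the category

cats = np.unique(x)

# Brute force: every non-trivial subset of categories sent to the left child.
brute = min(
    sse(y[np.isin(x, left)]) + sse(y[~np.isin(x, left)])
    for r in range(1, cats.size)
    for left in combinations(cats, r)
)

# Breiman shortcut: only prefixes of the categories ordered by mean target.
order = sorted(cats, key=lambda c: y[x == c].mean())
shortcut = min(
    sse(y[np.isin(x, order[:k])]) + sse(y[~np.isin(x, order[:k])])
    for k in range(1, cats.size)
)

assert np.isclose(brute, shortcut)
```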
@@ -84,6 +85,7 @@ cdef class Splitter: const float64_t[:, ::1] y, const float64_t[:] sample_weight, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ) except -1 cdef int node_reset( diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index e1d9fb7d3b0e2..a133fd788b101 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -50,7 +50,6 @@ cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil self.impurity_right = INFINITY self.pos = start_pos self.feature = 0 - self.threshold = 0. self.improvement = -INFINITY self.missing_go_to_left = False self.n_missing = 0 @@ -130,6 +129,7 @@ cdef class Splitter: const float64_t[:, ::1] y, const float64_t[:] sample_weight, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ) except -1: """Initialize the splitter. @@ -282,7 +282,6 @@ cdef inline int node_split_best( # Find the best split cdef intp_t start = splitter.start cdef intp_t end = splitter.end - cdef intp_t end_non_missing cdef intp_t n_missing = 0 cdef bint has_missing = 0 cdef intp_t n_searches @@ -293,7 +292,6 @@ cdef inline int node_split_best( cdef intp_t[::1] constant_features = splitter.constant_features cdef intp_t n_features = splitter.n_features - cdef float32_t[::1] feature_values = splitter.feature_values cdef intp_t max_features = splitter.max_features cdef intp_t min_samples_leaf = splitter.min_samples_leaf cdef float64_t min_weight_leaf = splitter.min_weight_leaf @@ -309,6 +307,7 @@ cdef inline int node_split_best( cdef intp_t f_i = n_features cdef intp_t f_j + cdef bint is_constant cdef intp_t p cdef intp_t p_prev @@ -368,16 +367,10 @@ cdef inline int node_split_best( f_j += n_found_constants # f_j in the interval [n_total_constants, f_i[ current_split.feature = features[f_j] - partitioner.sort_samples_and_feature_values(current_split.feature) + is_constant = partitioner.sort_samples_and_feature_values(current_split.feature) n_missing = partitioner.n_missing - end_non_missing = end - n_missing - if ( - # All values for this feature are missing, or - end_non_missing == start or - # This feature is considered constant (max - min <= FEATURE_THRESHOLD) - feature_values[end_non_missing - 1] <= feature_values[start] + FEATURE_THRESHOLD - ): + if is_constant: # We consider this feature constant in this case. # Since finding a split among constant feature is not valuable, # we do not consider this feature for splitting. 
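Editor's note on the new `pos_to_threshold` helpers (and the inline code they replace in the next hunk): the comment "sum of halves is used to avoid infinite value" refers to float overflow. For feature values near the top of the float range, `(a + b) / 2` overflows to `inf`, while `a / 2 + b / 2` stays finite. A tiny float32 illustration, using nothing from this PR:

```python
# Midpoint of two large float32 values: the naive form overflows, the
# sum-of-halves form used by pos_to_threshold does not.
import numpy as np

a = np.float32(3.0e38)
b = np.float32(3.2e38)
two = np.float32(2.0)

assert np.isinf((a + b) / two)            # a + b exceeds float32 max -> inf
assert np.isfinite(a / two + b / two)     # halves are summed safely
```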
@@ -446,21 +439,7 @@ cdef inline int node_split_best( if current_proxy_improvement > best_proxy_improvement: best_proxy_improvement = current_proxy_improvement - if p == end_non_missing and not missing_go_to_left: - # split with the right node being only the missing values - current_split.threshold = INFINITY - else: - # split between two non-missing values - # sum of halves is used to avoid infinite value - current_split.threshold = ( - feature_values[p_prev] / 2.0 + feature_values[p] / 2.0 - ) - if ( - current_split.threshold == INFINITY or - current_split.threshold == -INFINITY - ): - current_split.threshold = feature_values[p_prev] - + current_split.value = partitioner.pos_to_threshold(p_prev, p) current_split.n_missing = n_missing # if there are no missing values in the training data, during @@ -477,7 +456,7 @@ cdef inline int node_split_best( if best_split.pos < end: partitioner.partition_samples_final( best_split.pos, - best_split.threshold, + best_split.value, best_split.feature, best_split.missing_go_to_left ) @@ -636,7 +615,7 @@ cdef inline int node_split_random( has_missing = n_missing != 0 # Draw a random threshold - current_split.threshold = rand_uniform( + current_split.value.threshold = rand_uniform( min_feature_value, max_feature_value, random_state, @@ -656,12 +635,12 @@ cdef inline int node_split_random( else: missing_go_to_left = 0 - if current_split.threshold == max_feature_value: - current_split.threshold = min_feature_value + if current_split.value.threshold == max_feature_value: + current_split.value.threshold = min_feature_value # Partition current_split.pos = partitioner.partition_samples( - current_split.threshold, missing_go_to_left + current_split.value.threshold, missing_go_to_left ) n_left = current_split.pos - start @@ -715,7 +694,7 @@ cdef inline int node_split_random( if current_split.feature != best_split.feature: partitioner.partition_samples_final( best_split.pos, - best_split.threshold, + best_split.value, best_split.feature, best_split.missing_go_to_left ) @@ -756,10 +735,12 @@ cdef class BestSplitter(Splitter): const float64_t[:, ::1] y, const float64_t[:] sample_weight, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ) except -1: - Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask, n_categories_in_feature) self.partitioner = DensePartitioner( - X, self.samples, self.feature_values, missing_values_in_feature_mask + X, y, sample_weight, self.samples, self.feature_values, + missing_values_in_feature_mask, n_categories_in_feature ) cdef int node_split( @@ -784,8 +765,9 @@ cdef class BestSparseSplitter(Splitter): const float64_t[:, ::1] y, const float64_t[:] sample_weight, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ) except -1: - Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask, n_categories_in_feature) self.partitioner = SparsePartitioner( X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask ) @@ -812,10 +794,12 @@ cdef class RandomSplitter(Splitter): const float64_t[:, ::1] y, const float64_t[:] sample_weight, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ) except -1: - Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + Splitter.init(self, X, y, 
sample_weight, missing_values_in_feature_mask, n_categories_in_feature) self.partitioner = DensePartitioner( - X, self.samples, self.feature_values, missing_values_in_feature_mask + X, y, sample_weight, self.samples, self.feature_values, + missing_values_in_feature_mask, n_categories_in_feature ) cdef int node_split( @@ -840,8 +824,9 @@ cdef class RandomSparseSplitter(Splitter): const float64_t[:, ::1] y, const float64_t[:] sample_weight, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ) except -1: - Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask, n_categories_in_feature) self.partitioner = SparsePartitioner( X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask ) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 593f8d0c5f542..fcd5dceb4fae7 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -6,22 +6,13 @@ import numpy as np cimport numpy as cnp -from sklearn.utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint8_t, uint32_t +from sklearn.utils._typedefs cimport ( + float32_t, float64_t, intp_t, int32_t, uint8_t, uint32_t, uint64_t +) from sklearn.tree._splitter cimport Splitter from sklearn.tree._splitter cimport SplitRecord - -cdef struct Node: - # Base storage structure for the nodes in a Tree object - - intp_t left_child # id of the left child of the node - intp_t right_child # id of the right child of the node - intp_t feature # Feature used for splitting the node - float64_t threshold # Threshold value at the node - float64_t impurity # Impurity of the node (i.e., the value of the criterion) - intp_t n_node_samples # Number of samples at the node - float64_t weighted_n_node_samples # Weighted number of samples at the node - uint8_t missing_go_to_left # Whether features have missing values +from sklearn.tree._utils cimport Node, SplitValue cdef struct ParentInfo: @@ -44,6 +35,9 @@ cdef class Tree: cdef public intp_t n_outputs # Number of outputs in y cdef public intp_t max_n_classes # max(n_classes) + # FIXME: change to uint8_t: but the error it triggers might be a Cython bug + cdef intp_t* is_categorical # Shape (n_features,) + # Inner structures: values are stored separately from node structure, # since size is determined at runtime. 
cdef public intp_t max_depth # Max depth of the tree @@ -55,8 +49,8 @@ cdef class Tree: # Methods cdef intp_t _add_node(self, intp_t parent, bint is_left, bint is_leaf, - intp_t feature, float64_t threshold, float64_t impurity, - intp_t n_node_samples, + intp_t feature, float64_t threshold, uint64_t cat_split, + float64_t impurity, intp_t n_node_samples, float64_t weighted_n_node_samples, uint8_t missing_go_to_left) except -1 nogil cdef int _resize(self, intp_t capacity) except -1 nogil diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 7044673189fb6..0537961fca776 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -145,6 +145,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): const float64_t[:, ::1] y, const float64_t[:] sample_weight=None, const uint8_t[::1] missing_values_in_feature_mask=None, + const intp_t [::1] n_categories_in_feature=None, ): """Build a decision tree from the training set (X, y).""" @@ -170,7 +171,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef float64_t min_impurity_decrease = self.min_impurity_decrease # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight, missing_values_in_feature_mask) + splitter.init(X, y, sample_weight, missing_values_in_feature_mask, n_categories_in_feature) cdef intp_t start cdef intp_t end @@ -254,7 +255,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): min_impurity_decrease)) node_id = tree._add_node(parent, is_left, is_leaf, split.feature, - split.threshold, parent_record.impurity, + split.value.threshold, split.value.cat_split, + parent_record.impurity, n_node_samples, weighted_n_node_samples, split.missing_go_to_left) @@ -400,6 +402,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): const float64_t[:, ::1] y, const float64_t[:] sample_weight=None, const uint8_t[::1] missing_values_in_feature_mask=None, + const intp_t[::1] n_categories_in_feature=None, ): """Build a decision tree from the training set (X, y).""" @@ -411,7 +414,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef intp_t max_leaf_nodes = self.max_leaf_nodes # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight, missing_values_in_feature_mask) + splitter.init(X, y, sample_weight, missing_values_in_feature_mask, n_categories_in_feature) cdef vector[FrontierRecord] frontier cdef FrontierRecord record @@ -467,6 +470,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): node.right_child = _TREE_LEAF node.feature = _TREE_UNDEFINED node.threshold = _TREE_UNDEFINED + # node.categorical_bitset = _TREE_UNDEFINED else: # Node is expandable @@ -611,8 +615,9 @@ cdef class BestFirstTreeBuilder(TreeBuilder): node_id = tree._add_node(parent - tree.nodes if parent != NULL else _TREE_UNDEFINED, - is_left, is_leaf, - split.feature, split.threshold, parent_record.impurity, + is_left, is_leaf, split.feature, + split.value.threshold, split.value.cat_split, + parent_record.impurity, n_node_samples, weighted_n_node_samples, split.missing_go_to_left) if node_id == INTPTR_MAX: @@ -746,6 +751,10 @@ cdef class Tree: def threshold(self): return self._get_node_ndarray()['threshold'][:self.node_count] + @property + def categorical_bitset(self): + return self._get_node_ndarray()['categorical_bitset'][:self.node_count] + @property def impurity(self): return self._get_node_ndarray()['impurity'][:self.node_count] @@ -768,7 +777,7 @@ cdef class Tree: # TODO: Convert n_classes to cython.integral memory view once # https://github.com/cython/cython/issues/5243 is fixed - def __cinit__(self, intp_t 
n_features, cnp.ndarray n_classes, intp_t n_outputs): + def __cinit__(self, intp_t n_features, cnp.ndarray n_classes, intp_t n_outputs, cnp.ndarray is_categorical): """Constructor.""" cdef intp_t dummy = 0 size_t_dtype = np.array(dummy).dtype @@ -788,6 +797,16 @@ cdef class Tree: for k in range(n_outputs): self.n_classes[k] = n_classes[k] + self.is_categorical = NULL + safe_realloc(&self.is_categorical, n_features) + if is_categorical is None: + for f in range(n_features): + self.is_categorical[f] = False + else: + is_categorical = is_categorical.astype(np.intp) + for f in range(n_features): + self.is_categorical[f] = is_categorical[f] + # Inner structures self.max_depth = 0 self.node_count = 0 @@ -798,15 +817,19 @@ cdef class Tree: def __dealloc__(self): """Destructor.""" # Free all inner structures + free(self.is_categorical) free(self.n_classes) free(self.value) free(self.nodes) def __reduce__(self): """Reduce re-implementation, for pickling.""" - return (Tree, (self.n_features, - sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), - self.n_outputs), self.__getstate__()) + return (Tree, ( + self.n_features, + sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), + self.n_outputs, + sizet_ptr_to_ndarray(self.is_categorical, self.n_features) + ), self.__getstate__()) def __getstate__(self): """Getstate re-implementation, for pickling.""" @@ -895,8 +918,8 @@ cdef class Tree: return 0 cdef intp_t _add_node(self, intp_t parent, bint is_left, bint is_leaf, - intp_t feature, float64_t threshold, float64_t impurity, - intp_t n_node_samples, + intp_t feature, float64_t threshold, uint64_t cat_split, + float64_t impurity, intp_t n_node_samples, float64_t weighted_n_node_samples, uint8_t missing_go_to_left) except -1 nogil: """Add a node to the tree. @@ -927,11 +950,17 @@ cdef class Tree: node.right_child = _TREE_LEAF node.feature = _TREE_UNDEFINED node.threshold = _TREE_UNDEFINED + # node.categorical_bitset = _TREE_UNDEFINED else: # left_child and right_child will be set later node.feature = feature - node.threshold = threshold + if self.is_categorical[feature]: + node.threshold = -INFINITY + node.categorical_bitset = cat_split + else: + node.threshold = threshold + node.categorical_bitset = 0 node.missing_go_to_left = missing_go_to_left self.node_count += 1 @@ -981,13 +1010,18 @@ cdef class Tree: node = self.nodes # While node not a leaf while node.left_child != _TREE_LEAF: - X_i_node_feature = X_ndarray[i, node.feature] # ... and node.right_child != _TREE_LEAF: + X_i_node_feature = X_ndarray[i, node.feature] if isnan(X_i_node_feature): if node.missing_go_to_left: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] + elif self.is_categorical[node.feature]: + if node.categorical_bitset & (1 << ( X_i_node_feature)): + node = &self.nodes[node.left_child] + else: + node = &self.nodes[node.right_child] elif X_i_node_feature <= node.threshold: node = &self.nodes[node.left_child] else: @@ -1109,13 +1143,17 @@ cdef class Tree: # ... 
and node.right_child != _TREE_LEAF: indices[indptr[i + 1]] = (node - self.nodes) indptr[i + 1] += 1 - X_i_node_feature = X_ndarray[i, node.feature] if isnan(X_i_node_feature): if node.missing_go_to_left: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] + elif self.is_categorical[node.feature]: + if node.categorical_bitset & (1 << ( X_ndarray[i, node.feature])): + node = &self.nodes[node.left_child] + else: + node = &self.nodes[node.right_child] elif X_i_node_feature <= node.threshold: node = &self.nodes[node.left_child] else: @@ -1396,6 +1434,7 @@ cdef class Tree: if is_target_feature: # In this case, we push left or right child on stack + # TODO: handle categorical (and missing?) if X[sample_idx, feature_idx] <= current_node.threshold: node_idx_stack[stack_size] = current_node.left_child else: @@ -1936,7 +1975,8 @@ cdef void _build_pruned_tree( break new_node_id = tree._add_node( - parent, is_left, is_leaf, node.feature, node.threshold, + parent, is_left, is_leaf, node.feature, + node.threshold, node.categorical_bitset, node.impurity, node.n_node_samples, node.weighted_n_node_samples, node.missing_go_to_left) diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 9692fd9e8c809..426fd29df51bf 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -4,9 +4,35 @@ # See _utils.pyx for details. cimport numpy as cnp -from sklearn.tree._tree cimport Node from sklearn.neighbors._quad_tree cimport Cell -from sklearn.utils._typedefs cimport float32_t, float64_t, intp_t, uint8_t, int32_t, uint32_t +from sklearn.utils._typedefs cimport ( + float32_t, float64_t, intp_t, uint8_t, int32_t, uint32_t, uint64_t +) + + +ctypedef union SplitValue: + # Union type to generalize the concept of a threshold to categorical + # features. The floating point view, i.e. ``split_value.threshold`` is used + # for numerical features, where feature values less than or equal to the + # threshold go left, and values greater than the threshold go right. 
+ # + # For categorical features, TODO + float64_t threshold + uint64_t cat_split # bitset + + +cdef struct Node: + # Base storage structure for the nodes in a Tree object + + intp_t left_child # id of the left child of the node + intp_t right_child # id of the right child of the node + intp_t feature # Feature used for splitting the node + float64_t threshold # Threshold value at the node, for continuous split (-INF otherwise) + uint64_t categorical_bitset # Bitset for categorical split (0 otherwise) + float64_t impurity # Impurity of the node (i.e., the value of the criterion) + intp_t n_node_samples # Number of samples at the node + float64_t weighted_n_node_samples # Weighted number of samples at the node + uint8_t missing_go_to_left # Whether features have missing values cdef enum: diff --git a/sklearn/tree/tests/test_split.py b/sklearn/tree/tests/test_split.py new file mode 100644 index 0000000000000..af1734bffc249 --- /dev/null +++ b/sklearn/tree/tests/test_split.py @@ -0,0 +1,281 @@ +from dataclasses import dataclass +from functools import cached_property, partial +from itertools import chain, combinations +from operator import itemgetter + +import numpy as np +import pytest +from scipy.sparse import csc_array + +from sklearn.metrics import ( + log_loss, + mean_absolute_error, + mean_poisson_deviance, + mean_squared_error, +) +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor + +CLF_CRITERIONS = ("gini", "log_loss") +REG_CRITERIONS = ("squared_error", "absolute_error", "friedman_mse", "poisson") + + +def powerset(iterable): + s = list(iterable) # allows handling sets too + return chain.from_iterable( + (list(c) for c in combinations(s, r)) for r in range(1, (len(s) + 1) // 2 + 1) + ) + + +@dataclass +class NaiveSplitter: + is_clf: bool + criterion: str + with_nans: bool + n_classes: int + is_categorical: np.ndarray + + @staticmethod + def weighted_median(y, w): + sorter = np.argsort(y) + wc = np.cumsum(w[sorter]) + idx = np.searchsorted(wc, wc[-1] / 2) + return y[sorter[idx]] + + @staticmethod + def gini_loss(y, y_pred, sample_weight): + p = y_pred[0] + return (p * (1 - p)).sum() + + @cached_property + def loss(self): + losses = { + "poisson": mean_poisson_deviance, + "squared_error": mean_squared_error, + "absolute_error": mean_absolute_error, + "log_loss": partial(log_loss, labels=np.arange(self.n_classes)), + "gini": self.gini_loss, + } + return losses[self.criterion] + + def class_ratios(self, y, w): + return np.clip(np.bincount(y, w, minlength=self.n_classes) / w.sum(), None, 1) + + @cached_property + def predictor(self): + if self.is_clf: + return self.class_ratios + elif self.criterion == "absolute_error": + return self.weighted_median + else: + return lambda y, w: np.average(y, weights=w) + + def compute_child_loss(self, y: np.ndarray, w: np.ndarray): + if y.size == 0: + return np.inf + pred_dim = (y.size, self.n_classes) if self.is_clf else (y.size,) + y_pred = np.empty(pred_dim) + y_pred[:] = self.predictor(y, w) + return w.sum() * self.loss(y, y_pred, sample_weight=w) + + def compute_split_loss( + self, x, y, w, threshold=None, categories=None, missing_left=False + ): + if categories is not None: + mask_c = np.zeros(int(x.max() + 1), dtype=bool) + mask_c[categories] = True + mask = mask_c[x.astype(int)] + else: + mask = x < threshold + if missing_left: + mask |= np.isnan(x) + if self.criterion == "friedman_mse": + diff = np.average(y[mask], weights=w[mask]) - np.average( + y[~mask], weights=w[~mask] + ) + return (-(diff**2) * w[mask].sum() * 
w[~mask].sum() / w.sum(),) + return ( + self.compute_child_loss(y[mask], w[mask]), + self.compute_child_loss(y[~mask], w[~mask]), + ) + + def compute_all_losses(self, x, y, w, is_categorical=False, missing_left=False): + if is_categorical: + return self.compute_all_losses_categorical(x, y, w) + nan_mask = np.isnan(x) + xu = np.unique(x[~nan_mask], sorted=True) + thresholds = (xu[1:] + xu[:-1]) / 2 + if nan_mask.any() and not missing_left: + thresholds = np.append(thresholds, xu.max() * 2) + return thresholds, [ + sum(self.compute_split_loss(x, y, w, threshold, missing_left=missing_left)) + for threshold in thresholds + ] + + def compute_all_losses_categorical(self, x, y, w): + cat_splits = list(powerset(np.unique(x).astype(int))) + return cat_splits, [ + sum(self.compute_split_loss(x, y, w, categories=left_cat)) + for left_cat in cat_splits + ] + + def best_split_naive(self, X, y, w): + splits = [] + for f in range(X.shape[1]): + thresholds, losses = self.compute_all_losses( + X[:, f], y, w, is_categorical=self.is_categorical[f] + ) + if self.with_nans: + thresholds_, losses_ = self.compute_all_losses( + X[:, f], y, w, missing_left=True + ) + thresholds = np.concat((thresholds, thresholds_)) + losses = np.concat((losses, losses_)) + if len(losses) == 0: + continue + idx = np.argmin(losses) + splits.append( + ( + losses[idx], + thresholds[idx], + self.with_nans and idx >= thresholds.size // 2, + f, + ) + ) + return min(splits, key=itemgetter(0)) + + +def sparsify(X, density): + X -= 0.5 + th_low = np.quantile(X.ravel(), q=density / 2) + th_up = np.quantile(X.ravel(), q=1 - density / 2) + X[(th_low < X) & (X < th_up)] = 0 + return csc_array(X) + + +def to_categorical(x, nc): + q = np.linspace(0, 1, num=nc + 1)[1:-1] + quantiles = np.quantile(x, q) + cats = np.searchsorted(quantiles, x) + return np.random.permutation(nc)[cats] + + +def make_simple_dataset( + n, + d, + with_nans, + is_sparse, + is_categorical, + is_clf, + n_classes, + rng: np.random.Generator, +): + X_dense = np.random.rand(n, d) + y = np.random.rand(n) + X_dense.sum(axis=1) + w = np.random.rand(n) + + for idx in np.where(is_categorical)[0]: + nc = rng.integers(2, 6) # cant go to high or test will be too slow + X_dense[:, idx] = to_categorical(X_dense[:, idx], nc) + with_duplicates = rng.integers(2) == 0 + if with_duplicates: + X_dense = X_dense.round(1 if n < 50 else 2) + if with_nans: + for i in range(d): + step = rng.integers(2, 10) + X_dense[i::step, i] = np.nan + if is_sparse: + density = rng.uniform(0.05, 0.99) + X = sparsify(X_dense, density) + else: + X = X_dense + + if is_clf: + q = np.linspace(0, 1, num=n_classes + 1)[1:-1] + y = np.searchsorted(np.quantile(y, q), y) + + return X_dense, X, y, w + + +def bitset_to_set(v: np.uint64): + return [c for c in range(64) if v & (1 << c)] + + +@pytest.mark.parametrize("sparse", ["x", "sparse"]) +@pytest.mark.parametrize("categorical", ["x", "categorical"]) +@pytest.mark.parametrize("missing_values", ["x", "missing_values"]) +@pytest.mark.parametrize( + "criterion", + ["gini", "log_loss", "squared_error", "friedman_mse", "poisson"], +) +def test_best_split_optimality( + sparse, categorical, missing_values, criterion, global_random_seed +): + is_clf = criterion in CLF_CRITERIONS + with_nans = missing_values != "x" + is_sparse = sparse != "x" + with_categoricals = categorical != "x" + if is_sparse and with_nans: + pytest.skip("Sparse + missing values not supported yet") + if with_categoricals and (is_sparse or criterion == "absolute_error" or with_nans): + 
pytest.skip("Categorical features not supported in this case") + + rng = np.random.default_rng() + + ns = [5] * 5 + [10] * 5 + [30, 30] + if not with_categoricals and criterion != "log_loss": + ns.extend([30, 30, 30, 100, 100, 200]) + + for it, n in enumerate(ns): + d = rng.integers(1, 4) + n_classes = 2 if with_categoricals else rng.integers(2, 5) + if with_categoricals: + is_categorical = rng.random(d) < 0.5 + else: + is_categorical = np.zeros(d, dtype=bool) + X_dense, X, y, w = make_simple_dataset( + n, d, with_nans, is_sparse, is_categorical, is_clf, n_classes, rng + ) + + naive_splitter = NaiveSplitter( + is_clf, criterion, with_nans, n_classes, is_categorical + ) + best_split = naive_splitter.best_split_naive(X_dense, y, w) + + is_categorical = is_categorical if is_categorical.any() else None + Tree = DecisionTreeClassifier if is_clf else DecisionTreeRegressor + tree = Tree( + criterion=criterion, + max_depth=1, + max_features=d, + categorical_features=is_categorical, + ) + tree.fit(X, y, sample_weight=w) + + split_feature = tree.tree_.feature[0] + split = ( + {"threshold": tree.tree_.threshold[0]} + if is_categorical is None or not is_categorical[split_feature] + else {"categories": bitset_to_set(tree.tree_.categorical_bitset[0])} + ) + tree_loss = naive_splitter.compute_split_loss( + X_dense[:, tree.tree_.feature[0]], + y, + w, + **split, + missing_left=bool(tree.tree_.missing_go_to_left[0]), + ) + tree_split = ( + sum(tree_loss), + split, + bool(tree.tree_.missing_go_to_left[0]), + tree.tree_.feature[0], + ) + assert np.isclose(best_split[0], tree_split[0]), (it, best_split, tree_split) + + vals = tree.tree_.impurity * tree.tree_.weighted_n_node_samples + if criterion == "log_loss": + vals *= np.log(2) + if criterion == "poisson": + vals *= 2 + if criterion != "friedman_mse": + assert np.allclose(vals[1:], tree_loss), it diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 4a1aa11d62f8c..1ee6fa33d2ce8 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2091,11 +2091,21 @@ def test_criterion_entropy_same_as_log_loss(Tree, n_classes): assert_allclose(tree_log_loss.predict(X), tree_entropy.predict(X)) +def to_categorical(x, nc): + q = np.linspace(0, 1, num=nc + 1)[1:-1] + quantiles = np.quantile(x, q) + cats = np.searchsorted(quantiles, x) + return np.random.permutation(nc)[cats] + + def test_different_endianness_pickle(): - X, y = datasets.make_classification(random_state=0) + X, y = datasets.make_classification(random_state=0, n_redundant=0, shuffle=False) + X[:, 0] = to_categorical(X[:, 0], 50) - clf = DecisionTreeClassifier(random_state=0, max_depth=3) + clf = DecisionTreeClassifier(random_state=0, max_depth=3, categorical_features=[0]) clf.fit(X, y) + assert 0 < clf.feature_importances_[0] < 1 + # ^ ensures some splits are categorical, some are continuous score = clf.score(X, y) def reduce_ndarray(arr): @@ -2118,9 +2128,12 @@ def get_pickle_non_native_endianness(): def test_different_endianness_joblib_pickle(): X, y = datasets.make_classification(random_state=0) + X[:, 0] = to_categorical(X[:, 0], 50) - clf = DecisionTreeClassifier(random_state=0, max_depth=3) + clf = DecisionTreeClassifier(random_state=0, max_depth=3, categorical_features=[0]) clf.fit(X, y) + assert 0 < clf.feature_importances_[0] < 1 + # ^ ensures some splits are categorical, some are continuous score = clf.score(X, y) class NonNativeEndiannessNumpyPickler(NumpyPickler): @@ -2179,13 +2192,15 @@ def 
get_different_alignment_node_ndarray(node_ndarray): def reduce_tree_with_different_bitness(tree): new_dtype = np.int64 if _IS_32BIT else np.int32 - tree_cls, (n_features, n_classes, n_outputs), state = tree.__reduce__() + tree_cls, (n_features, n_classes, n_outputs, is_categorical), state = ( + tree.__reduce__() + ) new_n_classes = n_classes.astype(new_dtype, casting="same_kind") new_state = state.copy() new_state["nodes"] = get_different_bitness_node_ndarray(new_state["nodes"]) - return (tree_cls, (n_features, new_n_classes, n_outputs), new_state) + return (tree_cls, (n_features, new_n_classes, n_outputs, is_categorical), new_state) def test_different_bitness_pickle(): @@ -2826,7 +2841,9 @@ def test_build_pruned_tree_py(): tree.fit(iris.data, iris.target) n_classes = np.atleast_1d(tree.n_classes_) - pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_) + pruned_tree = CythonTree( + tree.n_features_in_, n_classes, tree.n_outputs_, tree.is_categorical_ + ) # only keep the root note leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8) @@ -2840,7 +2857,9 @@ def test_build_pruned_tree_py(): assert_array_equal(tree.tree_.value[0], pruned_tree.value[0]) # now keep all the leaves - pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_) + pruned_tree = CythonTree( + tree.n_features_in_, n_classes, tree.n_outputs_, tree.is_categorical_ + ) leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8) leave_in_subtree[1:] = 1 @@ -2858,7 +2877,9 @@ def test_build_pruned_tree_infinite_loop(): tree = DecisionTreeClassifier(random_state=0, max_depth=1) tree.fit(iris.data, iris.target) n_classes = np.atleast_1d(tree.n_classes_) - pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_) + pruned_tree = CythonTree( + tree.n_features_in_, n_classes, tree.n_outputs_, tree.is_categorical_ + ) # only keeping one child as a leaf results in an improper tree leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8) @@ -2889,3 +2910,23 @@ def test_sort_log2_build(): ] # fmt: on assert_array_equal(samples, expected_samples) + + +@pytest.mark.parametrize("Tree", [DecisionTreeClassifier, DecisionTreeRegressor]) +def test_categorical(Tree): + rng = np.random.default_rng(3) + n = 40 + c = rng.integers(0, 20, size=n) + y = c % 2 + + X = rng.random((n, 3)) + X[:, 0] = c + + tree = Tree(categorical_features=[0], max_depth=1, random_state=8) + # assert perfect tree was reached in one split + assert tree.fit(X, y).score(X, y) == 1 + assert tree.feature_importances_[0] == 1 + + # assert it's not the case without using categorical_features + tree = Tree(max_depth=1) + assert tree.fit(X, y).score(X, y) < 1
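Finally, a usage-level sketch of the `categorical_features` parameter this diff introduces, mirroring `test_categorical` above. The parameter only exists with this branch applied and its exact API may still change, so treat the snippet as provisional rather than as released scikit-learn behaviour:

```python
# Usage sketch of the categorical_features parameter added by this PR.
# Parameter name, semantics, and the tree_.categorical_bitset attribute all
# come from the diff above; nothing here is part of a released sklearn API.
import numpy as np

from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
n = 200
cat = rng.integers(0, 20, size=n)           # integer-coded categories in [0, 63]
X = np.column_stack([cat, rng.random(n)])   # column 0 categorical, column 1 numeric
y = cat % 2                                 # label depends only on category parity

# Declare column 0 as categorical, either by index (as here) or boolean mask.
clf = DecisionTreeClassifier(max_depth=1, categorical_features=[0], random_state=0)
clf.fit(X, y)

# A single categorical (bitset) split can separate the classes perfectly,
# which a single threshold split on the raw integer codes cannot do.
print(clf.score(X, y))                   # expected: 1.0 (as asserted in test_categorical)
print(clf.tree_.categorical_bitset[0])   # bitset of the categories sent to the left child
```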