Draft
Changes from all commits (22 commits)
91348de
test for optimal splits
cakedev0 Sep 15, 2025
7d38553
added friedman mse
cakedev0 Sep 17, 2025
f32d6c0
Compiles. Big chunk of the low-level machinery done; still needs to h…
cakedev0 Sep 23, 2025
35c10ac
Added bitset creation logic; TODO: using it.
cakedev0 Sep 23, 2025
61cdba7
added goes_left function
cakedev0 Sep 23, 2025
eaa6769
Merge branch 'tree-simpler-missing' into categorical
cakedev0 Sep 23, 2025
e8cbf55
Most changes done; all tests (non-categorical) pass; still remains to…
cakedev0 Sep 24, 2025
4c571af
added categorical_features kwarg to top-level classes; non-categorica…
cakedev0 Sep 24, 2025
a27b16c
use categorical split in apply/predict; basic test with categori…
cakedev0 Sep 24, 2025
288c1ea
refactor check_cat; better error messages
cakedev0 Sep 24, 2025
f512e03
fix segfault
cakedev0 Sep 24, 2025
eb7f94c
added basic but strong tests; fixed some minor but impactful bugs rev…
cakedev0 Sep 25, 2025
9afc1d8
Minor tests changes
cakedev0 Sep 25, 2025
9d8b93b
test sparse; clean-up
cakedev0 Sep 25, 2025
5c12733
Merge branch 'testing-split' into categorical
cakedev0 Sep 25, 2025
1d796cf
test categorical; found a new bug in missing
cakedev0 Sep 25, 2025
5c3de79
docstring for Breiman sort
cakedev0 Sep 25, 2025
f47ca8e
minor comments fixes
cakedev0 Sep 25, 2025
96998ee
addressed some Copilot comments
cakedev0 Sep 26, 2025
f49a185
Merge branch 'tree-simpler-missing' into categorical
cakedev0 Oct 7, 2025
09b03a0
enable more tests
cakedev0 Oct 7, 2025
4a59a7f
Merge branch 'tree-simpler-missing' into categorical
cakedev0 Oct 9, 2025
2 changes: 1 addition & 1 deletion sklearn/ensemble/_gradient_boosting.pyx
@@ -10,8 +10,8 @@ from scipy.sparse import issparse
from sklearn.utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint8_t
# Note: _tree uses cimport numpy, cnp.import_array, so we need to include
# numpy headers in the build configuration of this extension
from sklearn.tree._tree cimport Node
from sklearn.tree._tree cimport Tree
from sklearn.tree._utils cimport Node
from sklearn.tree._utils cimport safe_realloc


115 changes: 111 additions & 4 deletions sklearn/tree/_classes.py
@@ -125,6 +125,10 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
"min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")],
"ccp_alpha": [Interval(Real, 0.0, None, closed="left")],
"monotonic_cst": ["array-like", None],
"categorical_features": [
"array-like",
None,
],
}

@abstractmethod
@@ -144,6 +148,7 @@ def __init__(
class_weight=None,
ccp_alpha=0.0,
monotonic_cst=None,
categorical_features=None,
):
self.criterion = criterion
self.splitter = splitter
@@ -158,6 +163,7 @@ def __init__(
self.class_weight = class_weight
self.ccp_alpha = ccp_alpha
self.monotonic_cst = monotonic_cst
self.categorical_features = categorical_features

def get_depth(self):
"""Return the depth of the decision tree.
@@ -259,13 +265,18 @@ def _fit(
missing_values_in_feature_mask = (
self._compute_missing_values_in_feature_mask(X)
)
if issparse(X):
is_sparse_X = issparse(X)
if is_sparse_X:
X.sort_indices()

if X.indices.dtype != np.intc or X.indptr.dtype != np.intc:
raise ValueError(
"No support for np.int64 index based sparse matrices"
)
if is_sparse_X and self.categorical_features is not None:
raise NotImplementedError(
"Categorical features not supported with sparse inputs"
)

if self.criterion == "poisson":
if np.any(y < 0):
@@ -431,6 +442,10 @@ def _fit(
# *positive class*, all signs must be flipped.
monotonic_cst *= -1

self.is_categorical_, n_categories_in_feature = (
self._check_categorical_features(X, monotonic_cst)
)

if not isinstance(self.splitter, Splitter):
splitter = SPLITTERS[self.splitter](
criterion,
@@ -442,13 +457,19 @@ def _fit(
)

if is_classifier(self):
self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_)
self.tree_ = Tree(
self.n_features_in_,
self.n_classes_,
self.n_outputs_,
self.is_categorical_,
)
else:
self.tree_ = Tree(
self.n_features_in_,
# TODO: tree shouldn't need this in this case
np.array([1] * self.n_outputs_, dtype=np.intp),
self.n_outputs_,
self.is_categorical_,
)

# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
@@ -472,7 +493,14 @@ def _fit(
self.min_impurity_decrease,
)

builder.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask)
builder.build(
self.tree_,
X,
y,
sample_weight,
missing_values_in_feature_mask,
n_categories_in_feature,
)

if self.n_outputs_ == 1 and is_classifier(self):
self.n_classes_ = self.n_classes_[0]
@@ -482,6 +510,78 @@

return self

def _check_categorical_features(self, X, monotonic_cst):
"""Check and validate categorical features in X

Parameters
----------
X : {array-like} of shape (n_samples, n_features)
Input data (after `validate_data` was called)

Return
------
is_categorical : ndarray of shape (n_features,) or None, dtype=bool
Indicates whether a feature is categorical. If no feature is
categorical, this is None.
n_categories_in_feature: TODO
"""
n_features = X.shape[1]
categorical_features = np.asarray(self.categorical_features)

if self.categorical_features is None or categorical_features.size == 0:
is_categorical = np.zeros(n_features, dtype=bool)
elif categorical_features.dtype.kind not in ("i", "b"):
raise ValueError(
"categorical_features must be an array-like of bool or int, "
f"got: {categorical_features.dtype.name}."
)
elif categorical_features.dtype.kind == "i":
# check for categorical features as indices
if (
np.max(categorical_features) >= n_features
or np.min(categorical_features) < 0
):
raise ValueError(
"categorical_features set as integer "
"indices must be in [0, n_features - 1]"
)
is_categorical = np.zeros(n_features, dtype=bool)
is_categorical[categorical_features] = True
else:
if categorical_features.shape[0] != n_features:
raise ValueError(
"categorical_features set as a boolean mask "
"must have shape (n_features,), got: "
f"{categorical_features.shape}"
)
is_categorical = categorical_features

n_categories_in_feature = np.full(self.n_features_in_, -1, dtype=np.intp)
MAX_NC = 64 # TODO import from somewhere
base_msg = (
f"Values for categorical features should be integers in [0, {MAX_NC - 1}]."
)
for idx in np.where(is_categorical)[0]:
if np.isnan(X[:, idx]).any():
raise ValueError(
"Missing values are not supported in categorical features"
)
if not np.allclose(X[:, idx].astype(np.intp), X[:, idx]):
raise ValueError(f"{base_msg} Found non-integer values.")
if X[:, idx].min() < 0:
raise ValueError(f"{base_msg} Found negative values.")
X_idx_max = X[:, idx].max()
if X_idx_max >= MAX_NC:
raise ValueError(f"{base_msg} Found {X_idx_max}.")
n_categories_in_feature[idx] = X_idx_max + 1
if monotonic_cst is not None and monotonic_cst[idx] != 0:
raise ValueError(
"A categorical feature cannot have a non-null monotonic"
" constraint. "
)

return is_categorical, n_categories_in_feature

def _validate_X_predict(self, X, check_input):
"""Validate the training data on predict (probabilities)."""
if check_input:
@@ -620,13 +720,16 @@ def _prune_tree(self):
# build pruned tree
if is_classifier(self):
n_classes = np.atleast_1d(self.n_classes_)
pruned_tree = Tree(self.n_features_in_, n_classes, self.n_outputs_)
pruned_tree = Tree(
self.n_features_in_, n_classes, self.n_outputs_, self.is_categorical_
)
else:
pruned_tree = Tree(
self.n_features_in_,
# TODO: the tree shouldn't need this param
np.array([1] * self.n_outputs_, dtype=np.intp),
self.n_outputs_,
self.is_categorical_,
)
_build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha)

@@ -976,6 +1079,7 @@ def __init__(
class_weight=None,
ccp_alpha=0.0,
monotonic_cst=None,
categorical_features=None,
):
super().__init__(
criterion=criterion,
@@ -991,6 +1095,7 @@ def __init__(
min_impurity_decrease=min_impurity_decrease,
monotonic_cst=monotonic_cst,
ccp_alpha=ccp_alpha,
categorical_features=categorical_features,
)

@_fit_context(prefer_skip_nested_validation=True)
@@ -1353,6 +1458,7 @@ def __init__(
min_impurity_decrease=0.0,
ccp_alpha=0.0,
monotonic_cst=None,
categorical_features=None,
):
super().__init__(
criterion=criterion,
@@ -1367,6 +1473,7 @@ def __init__(
min_impurity_decrease=min_impurity_decrease,
ccp_alpha=ccp_alpha,
monotonic_cst=monotonic_cst,
categorical_features=categorical_features,
)

@_fit_context(prefer_skip_nested_validation=True)
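Note for reviewers: taken together, the _classes.py changes above expose a new categorical_features kwarg on the tree estimators. A minimal usage sketch, assuming only what the diff shows (integer category codes in [0, 63]; either integer indices or a boolean mask are accepted); the data and variable names are illustrative, not from this PR:

import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
X = np.column_stack([
    rng.normal(size=200),          # feature 0: numerical
    rng.integers(0, 5, size=200),  # feature 1: categorical, codes in [0, 63]
]).astype(np.float64)
y = (X[:, 1] >= 2).astype(int)

# Equivalent ways to mark feature 1 as categorical:
clf = DecisionTreeClassifier(random_state=0, categorical_features=[1]).fit(X, y)
clf = DecisionTreeClassifier(
    random_state=0, categorical_features=[False, True]
).fit(X, y)

Per _check_categorical_features, NaNs, non-integer or negative codes, codes >= 64, sparse X, and a non-zero monotonic_cst on a categorical feature are all rejected.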
31 changes: 26 additions & 5 deletions sklearn/tree/_partitioner.pxd
@@ -3,10 +3,12 @@

# See _partitioner.pyx for details.


from sklearn.utils._typedefs cimport (
float32_t, float64_t, int8_t, int32_t, intp_t, uint8_t, uint32_t
float32_t, float64_t, int8_t, int32_t, intp_t, uint8_t, uint32_t, uint64_t
)
from sklearn.tree._splitter cimport SplitRecord
from sklearn.tree._utils cimport SplitValue


# Mitigate precision differences between 32 bit and 64 bit
@@ -68,15 +70,25 @@ cdef class DensePartitioner:
Note that this partitioner is agnostic to the splitting strategy (best vs. random).
"""
cdef const float32_t[:, :] X
cdef const float64_t[:, :] y
cdef const float64_t[::1] sample_weight
cdef intp_t[::1] samples
cdef float32_t[::1] feature_values
cdef intp_t start
cdef intp_t end
cdef intp_t n_missing
cdef const uint8_t[::1] missing_values_in_feature_mask
cdef const intp_t[::1] n_categories_in_feature
cdef bint missing_on_the_left
cdef intp_t n_categories

cdef intp_t[::1] counts
cdef float64_t[::1] weighted_counts
cdef float64_t[::1] means
cdef intp_t[::1] sorted_cat
cdef intp_t[::1] offsets

cdef void sort_samples_and_feature_values(
cdef bint sort_samples_and_feature_values(
self, intp_t current_feature
) noexcept nogil
cdef void shift_missing_to_the_left(self) noexcept nogil
@@ -96,6 +108,9 @@ cdef class DensePartitioner:
intp_t* p_prev,
intp_t* p
) noexcept nogil
cdef inline SplitValue pos_to_threshold(
self, intp_t p_prev, intp_t p
) noexcept nogil
cdef intp_t partition_samples(
self,
float64_t current_threshold,
@@ -104,10 +119,12 @@
cdef void partition_samples_final(
self,
intp_t best_pos,
float64_t best_threshold,
SplitValue split_value,
intp_t best_feature,
bint best_missing_go_to_left,
) noexcept nogil
cdef void _breiman_sort_categories(self, intp_t nc) noexcept nogil
cdef inline uint64_t _split_pos_to_bitset(self, intp_t p, intp_t nc) noexcept nogil


cdef class SparsePartitioner:
@@ -132,8 +149,9 @@ cdef class SparsePartitioner:
cdef intp_t end
cdef intp_t n_missing
cdef const uint8_t[::1] missing_values_in_feature_mask
cdef intp_t n_categories

cdef void sort_samples_and_feature_values(
cdef bint sort_samples_and_feature_values(
self, intp_t current_feature
) noexcept nogil
cdef void shift_missing_to_the_left(self) noexcept nogil
@@ -153,6 +171,9 @@
intp_t* p_prev,
intp_t* p
) noexcept nogil
cdef inline SplitValue pos_to_threshold(
self, intp_t p_prev, intp_t p
) noexcept nogil
cdef intp_t partition_samples(
self,
float64_t current_threshold,
@@ -161,7 +182,7 @@
cdef void partition_samples_final(
self,
intp_t best_pos,
float64_t best_threshold,
SplitValue split_value,
intp_t best_feature,
bint best_missing_go_to_left,
) noexcept nogil
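Note for reviewers: the new _breiman_sort_categories and _split_pos_to_bitset declarations (together with the 64-category cap in _classes.py and the goes_left helper from commit 61cdba7) point at the classical Breiman shortcut: order a feature's categories by mean target value, scan only the nc - 1 prefix splits of that order, and encode the winning left subset as a single uint64 bitset. A rough NumPy sketch of the idea; the function names mirror the Cython declarations but are assumptions, not the actual implementation:

import numpy as np

def breiman_sort_categories(x_cat, y):
    # Sort categories by mean target: for regression (MSE) and binary
    # classification the optimal binary partition is a prefix of this
    # order (Breiman et al., 1984).
    cats = np.unique(x_cat).astype(np.int64)
    means = np.array([y[x_cat == c].mean() for c in cats])
    return cats[np.argsort(means)]

def split_pos_to_bitset(sorted_cats, p):
    # Encode "the first p categories of the sorted order go left" as a uint64.
    bitset = np.uint64(0)
    for c in sorted_cats[:p]:
        bitset |= np.uint64(1) << np.uint64(c)
    return bitset

def goes_left(bitset, category):
    # Test membership of a category in the left-subset bitset.
    return bool((bitset >> np.uint64(category)) & np.uint64(1))

# Tiny example: category 2 has the lowest mean target, so the first
# candidate split sends only category 2 to the left child.
x = np.array([0, 0, 1, 1, 2, 2])
y = np.array([1.0, 1.0, 2.0, 3.0, 0.0, 0.5])
order = breiman_sort_categories(x, y)   # -> [2, 0, 1]
bitset = split_pos_to_bitset(order, 1)  # bit 2 set
assert goes_left(bitset, 2) and not goes_left(bitset, 0)

Capping features at 64 categories keeps each split's left subset in one machine word, which is presumably why _classes.py rejects category codes >= 64.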