From 01571757387f5bb41fd89190a38a6bf3217c43f9 Mon Sep 17 00:00:00 2001 From: divakaivan Date: Thu, 7 Aug 2025 17:36:11 +0100 Subject: [PATCH 01/12] feat: Add predict_proba on SKLearnClassifier --- keras/src/wrappers/sklearn_wrapper.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py index 90d36c669792..8103f6a46c26 100644 --- a/keras/src/wrappers/sklearn_wrapper.py +++ b/keras/src/wrappers/sklearn_wrapper.py @@ -278,6 +278,14 @@ def dynamic_model(X, y, loss, layers=[10]): ``` """ + def predict_proba(self, X): + """Predict class probabilities of the input samples X.""" + from sklearn.utils.validation import check_is_fitted + + check_is_fitted(self) + X = _validate_data(self, X, reset=False) + return self.model_.predict(X) + def _process_target(self, y, reset=False): """Classifiers do OHE.""" target_type = type_of_target(y, raise_unknown=True) From eef00dc7f7696ad8d3919bacd7cd3f4c27c51875 Mon Sep 17 00:00:00 2001 From: divakaivan Date: Thu, 2 Oct 2025 13:34:51 +0100 Subject: [PATCH 02/12] feat: Add _estimator_has check --- keras/src/wrappers/utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/keras/src/wrappers/utils.py b/keras/src/wrappers/utils.py index 8c2954b055ad..cd2dea7bffe4 100644 --- a/keras/src/wrappers/utils.py +++ b/keras/src/wrappers/utils.py @@ -32,6 +32,21 @@ def _check_model(model): ) +def _estimator_has(attr): + def check(self): + from sklearn.utils.validation import check_is_fitted + + check_is_fitted(self) + return ( + True + if self.model_.layers[-1].activation.__name__ + in ("sigmoid", "softmax") + else False + ) + + return check + + class TargetReshaper(TransformerMixin, BaseEstimator): """Convert 1D targets to 2D and back. From 1c8f4f8eb17112d09737f232b2a699d819959a5f Mon Sep 17 00:00:00 2001 From: divakaivan Date: Thu, 2 Oct 2025 13:35:26 +0100 Subject: [PATCH 03/12] feat: Use available_if for classifier predict_proba --- keras/src/wrappers/sklearn_wrapper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py index 8103f6a46c26..ca38b513b8fa 100644 --- a/keras/src/wrappers/sklearn_wrapper.py +++ b/keras/src/wrappers/sklearn_wrapper.py @@ -10,6 +10,7 @@ from keras.src.wrappers.fixes import type_of_target from keras.src.wrappers.utils import TargetReshaper from keras.src.wrappers.utils import _check_model +from keras.src.wrappers.utils import _estimator_has from keras.src.wrappers.utils import assert_sklearn_installed try: @@ -278,6 +279,7 @@ def dynamic_model(X, y, loss, layers=[10]): ``` """ + @sklearn.utils._available_if.available_if(_estimator_has("predict_proba")) def predict_proba(self, X): """Predict class probabilities of the input samples X.""" from sklearn.utils.validation import check_is_fitted From c47e9568e00f87b70131fa758c6cd4bf64ebb6ed Mon Sep 17 00:00:00 2001 From: divakaivan Date: Thu, 2 Oct 2025 13:35:51 +0100 Subject: [PATCH 04/12] test: Add test checks for predict_proba --- keras/src/wrappers/sklearn_test.py | 45 ++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/keras/src/wrappers/sklearn_test.py b/keras/src/wrappers/sklearn_test.py index 250b12c51274..3daad291d932 100644 --- a/keras/src/wrappers/sklearn_test.py +++ b/keras/src/wrappers/sklearn_test.py @@ -57,7 +57,7 @@ def patched_more_tags(self): return parametrize_with_checks(estimators) -def dynamic_model(X, y, loss, layers=[10]): +def dynamic_model(X, y, loss, out_activation_function="softmax", layers=[10]): """Creates a basic MLP classifier dynamically choosing binary/multiclass classification loss and ouput activations. """ @@ -69,7 +69,7 @@ def dynamic_model(X, y, loss, layers=[10]): hidden = Dense(layer_size, activation="relu")(hidden) n_outputs = y.shape[1] if len(y.shape) > 1 else 1 - out = [Dense(n_outputs, activation="softmax")(hidden)] + out = [Dense(n_outputs, activation=out_activation_function)(hidden)] model = Model(inp, out) model.compile(loss=loss, optimizer="rmsprop") @@ -158,3 +158,44 @@ def test_sklearn_estimator_checks(estimator, check): pytest.xfail("Backend not implemented") else: raise + + +@pytest.mark.parametrize( + "estimator, has_predict_proba", + [ + ( + SKLearnClassifier( + model=dynamic_model, + model_kwargs={ + "out_activation_function": "softmax", + "loss": "categorical_crossentropy", + }, + fit_kwargs={"epochs": 1}, + ), + True, + ), + ( + SKLearnClassifier( + model=dynamic_model, + model_kwargs={ + "out_activation_function": "linear", + "loss": "categorical_crossentropy", + }, + fit_kwargs={"epochs": 1}, + ), + False, + ), + ], +) +def test_sklearn_estimator_predict_proba(estimator, has_predict_proba): + X, y = sklearn.datasets.make_classification( + n_samples=100, + n_features=10, + n_informative=4, + n_classes=4, + random_state=42, + ) + + estimator.fit(X, y) + + assert hasattr(estimator, "predict_proba") == has_predict_proba From b5f809ec393b674a9946e6c0eb9901d00232f4a7 Mon Sep 17 00:00:00 2001 From: divakaivan Date: Thu, 2 Oct 2025 17:43:03 +0100 Subject: [PATCH 05/12] test: Avoid numpy+openvino known errors --- keras/src/wrappers/sklearn_test.py | 32 ++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/keras/src/wrappers/sklearn_test.py b/keras/src/wrappers/sklearn_test.py index 3daad291d932..ba77de651e3e 100644 --- a/keras/src/wrappers/sklearn_test.py +++ b/keras/src/wrappers/sklearn_test.py @@ -188,14 +188,24 @@ def test_sklearn_estimator_checks(estimator, check): ], ) def test_sklearn_estimator_predict_proba(estimator, has_predict_proba): - X, y = sklearn.datasets.make_classification( - n_samples=100, - n_features=10, - n_informative=4, - n_classes=4, - random_state=42, - ) - - estimator.fit(X, y) - - assert hasattr(estimator, "predict_proba") == has_predict_proba + """Checks that ``SKLearnClassifier`` exposes the ``predict_proba`` method + only when the model outputs probabilities. + """ + try: + X, y = sklearn.datasets.make_classification( + n_samples=100, + n_features=10, + n_informative=4, + n_classes=4, + random_state=42, + ) + estimator.fit(X, y) + assert hasattr(estimator, "predict_proba") == has_predict_proba + except Exception as exc: + if keras.config.backend() in ["numpy", "openvino"] and ( + isinstance(exc, NotImplementedError) + or "NotImplementedError" in str(exc) + ): + pytest.xfail("Backend not implemented") + else: + raise From 654033428d4c069cccee07675409fbe4e50a58d0 Mon Sep 17 00:00:00 2001 From: Ivan Ivanov <54508530+divakaivan@users.noreply.github.com> Date: Fri, 3 Oct 2025 18:24:16 +0100 Subject: [PATCH 06/12] use proper sklearn available_if import path Co-authored-by: Adrin Jalali --- keras/src/wrappers/sklearn_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py index ca38b513b8fa..5999cbb748d0 100644 --- a/keras/src/wrappers/sklearn_wrapper.py +++ b/keras/src/wrappers/sklearn_wrapper.py @@ -279,7 +279,7 @@ def dynamic_model(X, y, loss, layers=[10]): ``` """ - @sklearn.utils._available_if.available_if(_estimator_has("predict_proba")) + @sklearn.utils.metaestimators.available_if(_estimator_has("predict_proba")) def predict_proba(self, X): """Predict class probabilities of the input samples X.""" from sklearn.utils.validation import check_is_fitted From 359e5d05b9fb739bf233b64866d651590d71b722 Mon Sep 17 00:00:00 2001 From: Ivan Ivanov <54508530+divakaivan@users.noreply.github.com> Date: Fri, 3 Oct 2025 18:25:25 +0100 Subject: [PATCH 07/12] fix Co-authored-by: Adrin Jalali --- keras/src/wrappers/utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/keras/src/wrappers/utils.py b/keras/src/wrappers/utils.py index cd2dea7bffe4..2d84481afe69 100644 --- a/keras/src/wrappers/utils.py +++ b/keras/src/wrappers/utils.py @@ -38,10 +38,7 @@ def check(self): check_is_fitted(self) return ( - True - if self.model_.layers[-1].activation.__name__ - in ("sigmoid", "softmax") - else False + self.model_.layers[-1].activation.__name__ in ("sigmoid", "softmax") ) return check From 437eb7f7e813b60f45836b02b84390f8ed5c482c Mon Sep 17 00:00:00 2001 From: divakaivan Date: Fri, 3 Oct 2025 19:43:49 +0100 Subject: [PATCH 08/12] update predict_proba check --- keras/src/wrappers/sklearn_wrapper.py | 5 +++-- keras/src/wrappers/utils.py | 15 +++++---------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py index 5999cbb748d0..b60c2d5cf117 100644 --- a/keras/src/wrappers/sklearn_wrapper.py +++ b/keras/src/wrappers/sklearn_wrapper.py @@ -10,7 +10,7 @@ from keras.src.wrappers.fixes import type_of_target from keras.src.wrappers.utils import TargetReshaper from keras.src.wrappers.utils import _check_model -from keras.src.wrappers.utils import _estimator_has +from keras.src.wrappers.utils import _estimator_has_proba from keras.src.wrappers.utils import assert_sklearn_installed try: @@ -19,6 +19,7 @@ from sklearn.base import ClassifierMixin from sklearn.base import RegressorMixin from sklearn.base import TransformerMixin + from sklearn.utils.metaestimators import available_if except ImportError: sklearn = None @@ -279,7 +280,7 @@ def dynamic_model(X, y, loss, layers=[10]): ``` """ - @sklearn.utils.metaestimators.available_if(_estimator_has("predict_proba")) + @available_if(_estimator_has_proba) def predict_proba(self, X): """Predict class probabilities of the input samples X.""" from sklearn.utils.validation import check_is_fitted diff --git a/keras/src/wrappers/utils.py b/keras/src/wrappers/utils.py index 2d84481afe69..efb0c6924e0d 100644 --- a/keras/src/wrappers/utils.py +++ b/keras/src/wrappers/utils.py @@ -32,16 +32,11 @@ def _check_model(model): ) -def _estimator_has(attr): - def check(self): - from sklearn.utils.validation import check_is_fitted - - check_is_fitted(self) - return ( - self.model_.layers[-1].activation.__name__ in ("sigmoid", "softmax") - ) - - return check +def _estimator_has_proba(self): + return self.model_.layers[-1].activation.__name__ in ( + "sigmoid", + "softmax", + ) class TargetReshaper(TransformerMixin, BaseEstimator): From 4bb59363f0e0a982fbbc902b39c49e535d04d1f9 Mon Sep 17 00:00:00 2001 From: divakaivan Date: Wed, 8 Oct 2025 15:40:42 +0100 Subject: [PATCH 09/12] add decision_function --- keras/src/wrappers/sklearn_test.py | 49 +++++++++++++-------------- keras/src/wrappers/sklearn_wrapper.py | 7 ++-- keras/src/wrappers/utils.py | 9 +---- 3 files changed, 26 insertions(+), 39 deletions(-) diff --git a/keras/src/wrappers/sklearn_test.py b/keras/src/wrappers/sklearn_test.py index ba77de651e3e..05b57e852d1b 100644 --- a/keras/src/wrappers/sklearn_test.py +++ b/keras/src/wrappers/sklearn_test.py @@ -161,46 +161,43 @@ def test_sklearn_estimator_checks(estimator, check): @pytest.mark.parametrize( - "estimator, has_predict_proba", + "estimator", [ - ( - SKLearnClassifier( - model=dynamic_model, - model_kwargs={ - "out_activation_function": "softmax", - "loss": "categorical_crossentropy", - }, - fit_kwargs={"epochs": 1}, - ), - True, + SKLearnClassifier( + model=dynamic_model, + model_kwargs={ + "out_activation_function": "softmax", + "loss": "categorical_crossentropy", + }, + fit_kwargs={"epochs": 1}, ), - ( - SKLearnClassifier( - model=dynamic_model, - model_kwargs={ - "out_activation_function": "linear", - "loss": "categorical_crossentropy", - }, - fit_kwargs={"epochs": 1}, - ), - False, + SKLearnClassifier( + model=dynamic_model, + model_kwargs={ + "out_activation_function": "linear", + "loss": "categorical_crossentropy", + }, + fit_kwargs={"epochs": 1}, ), ], ) -def test_sklearn_estimator_predict_proba(estimator, has_predict_proba): - """Checks that ``SKLearnClassifier`` exposes the ``predict_proba`` method - only when the model outputs probabilities. +def test_sklearn_estimator_decision_function(estimator): + """Checks that the argmax of ``decision_function`` is the same as that of + ``predict`` for classifiers. """ try: X, y = sklearn.datasets.make_classification( - n_samples=100, + n_samples=10, n_features=10, n_informative=4, n_classes=4, random_state=42, ) estimator.fit(X, y) - assert hasattr(estimator, "predict_proba") == has_predict_proba + assert ( + estimator.decision_function(X[:1]).argmax(axis=-1) + == estimator.predict(X[:1]).flatten() + ) except Exception as exc: if keras.config.backend() in ["numpy", "openvino"] and ( isinstance(exc, NotImplementedError) diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py index b60c2d5cf117..28a4ac701743 100644 --- a/keras/src/wrappers/sklearn_wrapper.py +++ b/keras/src/wrappers/sklearn_wrapper.py @@ -10,7 +10,6 @@ from keras.src.wrappers.fixes import type_of_target from keras.src.wrappers.utils import TargetReshaper from keras.src.wrappers.utils import _check_model -from keras.src.wrappers.utils import _estimator_has_proba from keras.src.wrappers.utils import assert_sklearn_installed try: @@ -19,7 +18,6 @@ from sklearn.base import ClassifierMixin from sklearn.base import RegressorMixin from sklearn.base import TransformerMixin - from sklearn.utils.metaestimators import available_if except ImportError: sklearn = None @@ -280,9 +278,8 @@ def dynamic_model(X, y, loss, layers=[10]): ``` """ - @available_if(_estimator_has_proba) - def predict_proba(self, X): - """Predict class probabilities of the input samples X.""" + def decision_function(self, X): + """Get raw model outputs.""" from sklearn.utils.validation import check_is_fitted check_is_fitted(self) diff --git a/keras/src/wrappers/utils.py b/keras/src/wrappers/utils.py index efb0c6924e0d..6d4bbaabdea2 100644 --- a/keras/src/wrappers/utils.py +++ b/keras/src/wrappers/utils.py @@ -32,13 +32,6 @@ def _check_model(model): ) -def _estimator_has_proba(self): - return self.model_.layers[-1].activation.__name__ in ( - "sigmoid", - "softmax", - ) - - class TargetReshaper(TransformerMixin, BaseEstimator): """Convert 1D targets to 2D and back. @@ -87,7 +80,7 @@ def inverse_transform(self, y): If the transformer was fit to a 1D numpy array, and a 2D numpy array with a singleton second dimension is passed, it will be squeezed back to 1D. Otherwise, it - will eb left untouched. + will be left untouched. """ from sklearn.utils.validation import check_is_fitted From 8dc1345d748125be28ffe701aa708e31181e63d5 Mon Sep 17 00:00:00 2001 From: divakaivan Date: Wed, 8 Oct 2025 16:31:23 +0100 Subject: [PATCH 10/12] fix decision_function tests --- keras/src/wrappers/sklearn_test.py | 2 +- keras/src/wrappers/sklearn_wrapper.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/keras/src/wrappers/sklearn_test.py b/keras/src/wrappers/sklearn_test.py index 05b57e852d1b..897b0e649766 100644 --- a/keras/src/wrappers/sklearn_test.py +++ b/keras/src/wrappers/sklearn_test.py @@ -182,7 +182,7 @@ def test_sklearn_estimator_checks(estimator, check): ], ) def test_sklearn_estimator_decision_function(estimator): - """Checks that the argmax of ``decision_function`` is the same as that of + """Checks that the argmax of ``decision_function`` is the same as ``predict`` for classifiers. """ try: diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py index 28a4ac701743..30f885d1606d 100644 --- a/keras/src/wrappers/sklearn_wrapper.py +++ b/keras/src/wrappers/sklearn_wrapper.py @@ -18,6 +18,7 @@ from sklearn.base import ClassifierMixin from sklearn.base import RegressorMixin from sklearn.base import TransformerMixin + from sklearn.utils._array_api import get_namespace except ImportError: sklearn = None @@ -283,8 +284,15 @@ def decision_function(self, X): from sklearn.utils.validation import check_is_fitted check_is_fitted(self) + xp, _ = get_namespace(X) + X = _validate_data(self, X, reset=False) - return self.model_.predict(X) + scores = self.model_.predict(X) + return ( + xp.reshape(scores, (-1,)) + if (scores.ndim > 1 and scores.shape[1] == 1) + else scores + ) def _process_target(self, y, reset=False): """Classifiers do OHE.""" From 5cd54c66b4dbe8392c246148d6a231f2484da9ce Mon Sep 17 00:00:00 2001 From: divakaivan Date: Wed, 8 Oct 2025 18:30:13 +0100 Subject: [PATCH 11/12] fix tests tmp --- keras/src/wrappers/sklearn_test.py | 9 ++++++--- keras/src/wrappers/sklearn_wrapper.py | 8 +------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/keras/src/wrappers/sklearn_test.py b/keras/src/wrappers/sklearn_test.py index 897b0e649766..6b3dc3fb59f3 100644 --- a/keras/src/wrappers/sklearn_test.py +++ b/keras/src/wrappers/sklearn_test.py @@ -107,6 +107,9 @@ def use_floatx(x): ), "check_supervised_y_2d": "This test assumes reproducibility in fit.", "check_fit_idempotent": "This test assumes reproducibility in fit.", + "check_classifiers_train": ( + "decision_function can return both probabilities and logits" + ), }, "SKLearnRegressor": { "check_parameters_default_constructible": ( @@ -167,7 +170,7 @@ def test_sklearn_estimator_checks(estimator, check): model=dynamic_model, model_kwargs={ "out_activation_function": "softmax", - "loss": "categorical_crossentropy", + "loss": "binary_crossentropy", }, fit_kwargs={"epochs": 1}, ), @@ -175,7 +178,7 @@ def test_sklearn_estimator_checks(estimator, check): model=dynamic_model, model_kwargs={ "out_activation_function": "linear", - "loss": "categorical_crossentropy", + "loss": "binary_crossentropy", }, fit_kwargs={"epochs": 1}, ), @@ -190,7 +193,7 @@ def test_sklearn_estimator_decision_function(estimator): n_samples=10, n_features=10, n_informative=4, - n_classes=4, + n_classes=2, random_state=42, ) estimator.fit(X, y) diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py index 30f885d1606d..02f5777bf6f5 100644 --- a/keras/src/wrappers/sklearn_wrapper.py +++ b/keras/src/wrappers/sklearn_wrapper.py @@ -284,15 +284,9 @@ def decision_function(self, X): from sklearn.utils.validation import check_is_fitted check_is_fitted(self) - xp, _ = get_namespace(X) X = _validate_data(self, X, reset=False) - scores = self.model_.predict(X) - return ( - xp.reshape(scores, (-1,)) - if (scores.ndim > 1 and scores.shape[1] == 1) - else scores - ) + return self.model_.predict(X) def _process_target(self, y, reset=False): """Classifiers do OHE.""" From 44594bd0a19195e36bc3c2c7298db953cbe9c46b Mon Sep 17 00:00:00 2001 From: divakaivan Date: Fri, 17 Oct 2025 23:01:01 +0100 Subject: [PATCH 12/12] fix ruff ci: remove unused import --- keras/src/wrappers/sklearn_wrapper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py index 02f5777bf6f5..5d92b8ae00ba 100644 --- a/keras/src/wrappers/sklearn_wrapper.py +++ b/keras/src/wrappers/sklearn_wrapper.py @@ -18,7 +18,6 @@ from sklearn.base import ClassifierMixin from sklearn.base import RegressorMixin from sklearn.base import TransformerMixin - from sklearn.utils._array_api import get_namespace except ImportError: sklearn = None