Skip to content

Commit

Permalink
Deprecation warning about multiclass, adapt tests, remove liblinear w…
Browse files Browse the repository at this point in the history
…ith multiple classes (because that is going to be unsupported)
  • Loading branch information
markotoplak committed Jan 9, 2025
1 parent b0ed95d commit 8660a32
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 8 deletions.
17 changes: 16 additions & 1 deletion Orange/classification/logistic_regression.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import warnings

import numpy as np
import sklearn.linear_model as skl_linear_model

from Orange.classification import SklLearner, SklModel
from Orange.preprocess import Normalize
from Orange.preprocess.score import LearnerScorer
from Orange.data import Variable, DiscreteVariable
from Orange.util import OrangeDeprecationWarning


__all__ = ["LogisticRegressionLearner"]

Expand Down Expand Up @@ -38,12 +42,23 @@ class LogisticRegressionLearner(SklLearner, _FeatureScorerMixin):
def __init__(self, penalty="l2", dual=False, tol=0.0001, C=1.0,
fit_intercept=True, intercept_scaling=1, class_weight=None,
random_state=None, solver="auto", max_iter=100,
verbose=0, n_jobs=1, preprocessors=None):
multi_class="deprecated", verbose=0, n_jobs=1, preprocessors=None):
if multi_class != "deprecated":
warnings.warn("The multi_class parameter was "
"deprecated in scikit-learn 1.5. Using it with "
"scikit-learn 1.7 will lead to a crash.",
OrangeDeprecationWarning,
stacklevel=2)
super().__init__(preprocessors=preprocessors)
self.params = vars()

def _initialize_wrapped(self):
params = self.params.copy()

multi_class = params.pop("multi_class")
if multi_class != "deprecated":
params["multi_class"] = multi_class

# The default scikit-learn solver `lbfgs` (v0.22) does not support the
# l1 penalty.
solver, penalty = params.pop("solver"), params.get("penalty")
Expand Down
2 changes: 1 addition & 1 deletion Orange/evaluation/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
--------
>>> import Orange
>>> data = Orange.data.Table('iris')
>>> learner = Orange.classification.LogisticRegressionLearner(solver="liblinear")
>>> learner = Orange.classification.LogisticRegressionLearner()
>>> results = Orange.evaluation.TestOnTrainingData(data, [learner])
"""
Expand Down
12 changes: 11 additions & 1 deletion Orange/tests/test_logistic_regression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring

from datetime import datetime
import unittest

import numpy as np
Expand All @@ -9,6 +10,7 @@
from Orange.data import Table, ContinuousVariable, Domain
from Orange.classification import LogisticRegressionLearner, Model
from Orange.evaluation import CrossValidation, CA
from Orange.util import OrangeDeprecationWarning


class TestLogisticRegressionLearner(unittest.TestCase):
Expand Down Expand Up @@ -149,8 +151,16 @@ def test_auto_solver(self):
# liblinear is default for l2 penalty
lr = LogisticRegressionLearner(penalty="l1", solver="auto")
skl_clf = lr._initialize_wrapped()
self.assertEqual(skl_clf.solver, "liblinear")
self.assertEqual(skl_clf.solver, "saga")
self.assertEqual(skl_clf.penalty, "l1")

def test_supports_weights(self):
self.assertTrue(LogisticRegressionLearner().supports_weights)

def test_multi_class_deprecation(self):
with self.assertWarns(OrangeDeprecationWarning):
LogisticRegressionLearner(penalty="l1", multi_class="multinomial")
now = datetime.now()
if (now.year, now.month) >= (2026, 1):
raise Exception("If Orange depends on scikit-learn >= 1.7, remove this test "
"and any mention of multi_class in LogisticRegressionLearner.")
2 changes: 1 addition & 1 deletion Orange/widgets/evaluate/tests/test_owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,7 +1303,7 @@ def test_output_error_cls(self):
log_reg = LogisticRegressionLearner()
self.send_signal(self.widget.Inputs.predictors, log_reg(data), 0)
self.send_signal(self.widget.Inputs.predictors,
LogisticRegressionLearner(penalty="l1")(data), 1)
LogisticRegressionLearner(penalty="l1", max_iter=1000)(data), 1)
with data.unlocked(data.Y):
data.Y[1] = np.nan
self.send_signal(self.widget.Inputs.data, data)
Expand Down
6 changes: 2 additions & 4 deletions Orange/widgets/visualize/tests/test_ownomogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,8 @@ def test_nomogram_nb_multiclass(self):
def test_nomogram_lr_multiclass(self):
"""Check probabilities for logistic regression classifier for various
values of classes and radio buttons for multiclass data"""
cls = LogisticRegressionLearner(
solver="liblinear"
)(self.lenses)
self._test_helper(cls, [9, 45, 52])
cls = LogisticRegressionLearner(max_iter=100)(self.lenses)
self._test_helper(cls, [18, 56, 78])

def test_nomogram_with_instance_nb(self):
"""Check initialized marker values and feature sorting for naive bayes
Expand Down

0 comments on commit 8660a32

Please sign in to comment.