Skip to content

Commit

Permalink
Deprecation warning about multiclass, adapt tests, remove liblinear w…
Browse files Browse the repository at this point in the history
…ith multiple classes (because that is going to be unsupported)
  • Loading branch information
markotoplak committed Jan 9, 2025
1 parent b0ed95d commit 823f73d
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 10 deletions.
17 changes: 16 additions & 1 deletion Orange/classification/logistic_regression.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import warnings

import numpy as np
import sklearn.linear_model as skl_linear_model

from Orange.classification import SklLearner, SklModel
from Orange.preprocess import Normalize
from Orange.preprocess.score import LearnerScorer
from Orange.data import Variable, DiscreteVariable
from Orange.util import OrangeDeprecationWarning


__all__ = ["LogisticRegressionLearner"]

Expand Down Expand Up @@ -38,12 +42,23 @@ class LogisticRegressionLearner(SklLearner, _FeatureScorerMixin):
def __init__(self, penalty="l2", dual=False, tol=0.0001, C=1.0,
fit_intercept=True, intercept_scaling=1, class_weight=None,
random_state=None, solver="auto", max_iter=100,
verbose=0, n_jobs=1, preprocessors=None):
multi_class="deprecated", verbose=0, n_jobs=1, preprocessors=None):
if multi_class != "deprecated":
warnings.warn("The multi_class parameter was "
"deprecated in scikit-learn 1.5. Using it with "
"scikit-learn 1.7 will lead to a crash.",
OrangeDeprecationWarning,
stacklevel=2)
super().__init__(preprocessors=preprocessors)
self.params = vars()

def _initialize_wrapped(self):
params = self.params.copy()

multi_class = params.pop("multi_class")
if multi_class != "deprecated":
params["multi_class"] = multi_class

Check warning on line 60 in Orange/classification/logistic_regression.py

View check run for this annotation

Codecov / codecov/patch

Orange/classification/logistic_regression.py#L60

Added line #L60 was not covered by tests

# The default scikit-learn solver `lbfgs` (v0.22) does not support the
# l1 penalty.
solver, penalty = params.pop("solver"), params.get("penalty")
Expand Down
4 changes: 2 additions & 2 deletions Orange/evaluation/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
--------
>>> import Orange
>>> data = Orange.data.Table('iris')
>>> learner = Orange.classification.LogisticRegressionLearner(solver="liblinear")
>>> learner = Orange.classification.LogisticRegressionLearner()
>>> results = Orange.evaluation.TestOnTrainingData(data, [learner])
"""
Expand Down Expand Up @@ -296,7 +296,7 @@ class LogLoss(ClassificationScore):
Examples
--------
>>> Orange.evaluation.LogLoss(results)
array([0.3...])
array([0.1...])
"""
__wraps__ = skl_metrics.log_loss
Expand Down
12 changes: 11 additions & 1 deletion Orange/tests/test_logistic_regression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring

from datetime import datetime
import unittest

import numpy as np
Expand All @@ -9,6 +10,7 @@
from Orange.data import Table, ContinuousVariable, Domain
from Orange.classification import LogisticRegressionLearner, Model
from Orange.evaluation import CrossValidation, CA
from Orange.util import OrangeDeprecationWarning


class TestLogisticRegressionLearner(unittest.TestCase):
Expand Down Expand Up @@ -149,8 +151,16 @@ def test_auto_solver(self):
# liblinear is default for l2 penalty
lr = LogisticRegressionLearner(penalty="l1", solver="auto")
skl_clf = lr._initialize_wrapped()
self.assertEqual(skl_clf.solver, "liblinear")
self.assertEqual(skl_clf.solver, "saga")
self.assertEqual(skl_clf.penalty, "l1")

def test_supports_weights(self):
self.assertTrue(LogisticRegressionLearner().supports_weights)

def test_multi_class_deprecation(self):
with self.assertWarns(OrangeDeprecationWarning):
LogisticRegressionLearner(penalty="l1", multi_class="multinomial")
now = datetime.now()
if (now.year, now.month) >= (2026, 1):
raise Exception("If Orange depends on scikit-learn >= 1.7, remove this test "
"and any mention of multi_class in LogisticRegressionLearner.")
4 changes: 2 additions & 2 deletions Orange/widgets/evaluate/tests/test_owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,7 +1303,7 @@ def test_output_error_cls(self):
log_reg = LogisticRegressionLearner()
self.send_signal(self.widget.Inputs.predictors, log_reg(data), 0)
self.send_signal(self.widget.Inputs.predictors,
LogisticRegressionLearner(penalty="l1")(data), 1)
LogisticRegressionLearner(penalty="l1", max_iter=1000)(data), 1)
with data.unlocked(data.Y):
data.Y[1] = np.nan
self.send_signal(self.widget.Inputs.data, data)
Expand All @@ -1315,7 +1315,7 @@ def test_output_error_cls(self):
names = [f"{log_reg.name}{x}" for x in names]
self.assertEqual(names, [m.name for m in pred.domain.metas])
self.assertAlmostEqual(pred.metas[0, 4], 0.018, 3)
self.assertAlmostEqual(pred.metas[0, 9], 0.113, 3)
self.assertAlmostEqual(pred.metas[0, 9], 0.008, 3)
self.assertTrue(np.isnan(pred.metas[1, 4]))
self.assertTrue(np.isnan(pred.metas[1, 9]))

Expand Down
6 changes: 2 additions & 4 deletions Orange/widgets/visualize/tests/test_ownomogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,8 @@ def test_nomogram_nb_multiclass(self):
def test_nomogram_lr_multiclass(self):
"""Check probabilities for logistic regression classifier for various
values of classes and radio buttons for multiclass data"""
cls = LogisticRegressionLearner(
solver="liblinear"
)(self.lenses)
self._test_helper(cls, [9, 45, 52])
cls = LogisticRegressionLearner(max_iter=100)(self.lenses)
self._test_helper(cls, [18, 56, 78])

def test_nomogram_with_instance_nb(self):
"""Check initialized marker values and feature sorting for naive bayes
Expand Down

0 comments on commit 823f73d

Please sign in to comment.