Skip to content

Commit

Permalink
Make y optional in GeneralizedLinearRegressorCV when using a formula (#833)
Browse files Browse the repository at this point in the history

* Make y in GeneralizedLinearRegressorCV optional

* Add test for formula-based glmcv
  • Loading branch information
stanmart authored Sep 9, 2024
1 parent c59db80 commit 28f0ff9
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/glum/_glm_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def _validate_hyperparameters(self) -> None:
def fit(
self,
X: ArrayLike,
y: ArrayLike,
y: Optional[ArrayLike] = None,
sample_weight: Optional[ArrayLike] = None,
offset: Optional[ArrayLike] = None,
*,
Expand Down
43 changes: 43 additions & 0 deletions tests/glm/test_cv_glm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pandas as pd
import pytest
import tabmat as tm
from scipy import sparse as sparse
Expand Down Expand Up @@ -123,3 +124,45 @@ def test_normal_ridge_comparison(fit_intercept):
np.testing.assert_allclose(glm_pred, el_pred, atol=4e-6)
np.testing.assert_allclose(glm.intercept_, ridge.intercept_, atol=4e-7)
np.testing.assert_allclose(glm.coef_, ridge.coef_, atol=3e-6)


def test_formula():
    """Check that a formula-based fit matches a fit on explicit ``X`` and ``y``.

    One estimator is given the formula ``"y ~ x1 + x2"`` and only the data
    frame (no separate ``y`` argument); the other is given the ``x1``/``x2``
    columns and the ``y`` column directly. Their coefficients and feature
    names are then compared.
    """
    n_samples = 100
    # Settings common to both estimators, so the two fits differ only in how
    # the data is handed over.
    shared_kwargs = dict(
        family="normal",
        fit_intercept=False,
        n_alphas=2,
        gradient_tol=1e-9,
    )

    # The draw order (y, then x1, then x2) is kept fixed so that the seeded
    # data is reproducible.
    np.random.seed(10)
    response = np.random.rand(n_samples)
    first_feature = np.random.rand(n_samples)
    second_feature = np.random.rand(n_samples)
    frame = pd.DataFrame(
        {"y": response, "x1": first_feature, "x2": second_feature}
    )

    fitted_from_formula = GeneralizedLinearRegressorCV(
        formula="y ~ x1 + x2", **shared_kwargs
    ).fit(frame)
    fitted_from_columns = GeneralizedLinearRegressorCV(**shared_kwargs).fit(
        frame[["x1", "x2"]], frame["y"]
    )

    np.testing.assert_almost_equal(
        fitted_from_columns.coef_, fitted_from_formula.coef_
    )
    np.testing.assert_array_equal(
        fitted_from_columns.feature_names_, fitted_from_formula.feature_names_
    )

0 comments on commit 28f0ff9

Please sign in to comment.