Skip to content

Commit e439afc

Browse files
committed
Minor update
Add n_features_in_
1 parent a4dfb66 commit e439afc

File tree

3 files changed

+302
-174
lines changed

3 files changed

+302
-174
lines changed

asgl/skmodels.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
from sklearn.decomposition import SparsePCA
99
from sklearn.utils import check_X_y
1010

11+
non_adaptive_penalizations = ['lasso', 'ridge', 'gl', 'sgl']
12+
adaptive_penalizations = ['alasso', 'aridge', 'agl', 'asgl']
13+
grouped_penalizations = ['gl', 'agl', 'sgl', 'asgl']
14+
1115

1216
class BaseModel(BaseEstimator, RegressorMixin):
1317
def __init__(self, model='lm', penalization='lasso', quantile=0.5, fit_intercept=True, lambda1=0.1, alpha=0.5,
@@ -20,12 +24,6 @@ def __init__(self, model='lm', penalization='lasso', quantile=0.5, fit_intercept
2024
self.alpha = alpha
2125
self.solver = solver
2226
self.tol = tol
23-
self.coef_ = None
24-
self.intercept_ = None
25-
self.solver_stats = None
26-
self.non_adaptive_penalizations = ['lasso', 'ridge', 'gl', 'sgl']
27-
self.adaptive_penalizations = ['alasso', 'aridge', 'agl', 'asgl']
28-
self.grouped_penalizations = ['gl', 'agl', 'sgl', 'asgl']
2927

3028
def _quantile_function(self, X):
3129
return 0.5 * cvxpy.abs(X) + (self.quantile - 0.5) * X
@@ -182,13 +180,14 @@ def _sgl(self, X, y, group_index):
182180

183181
def fit(self, X, y, group_index=None, sample_weight=None):
184182
X, y = check_X_y(X, y)
185-
if self.penalization in self.grouped_penalizations and group_index is None:
183+
self.n_features_in_ = X.shape[1]
184+
if self.penalization in grouped_penalizations and group_index is None:
186185
raise ValueError(
187186
f'The penalization provided requires fitting the model with a group_index parameter but no group_index was detected.')
188187
if self.penalization is None:
189188
beta_sol = self._unpenalized(X=X, y=y)
190189
self._split_beta_sol(beta_sol)
191-
elif self.penalization in self.non_adaptive_penalizations:
190+
elif self.penalization in non_adaptive_penalizations:
192191
beta_sol = getattr(self, '_' + self.penalization)(X=X, y=y, group_index=group_index)
193192
self._split_beta_sol(beta_sol)
194193
else:
@@ -453,6 +452,8 @@ class Regressor(BaseModel, AdaptiveWeights):
453452
Estimated coefficients for the regression problem.
454453
intercept_: float
455454
Independent term in the regression model
455+
n_features_in_: int
456+
Number of features seen during fit.
456457
"""
457458

458459
def __init__(self, model='lm', penalization='lasso', quantile=0.5, fit_intercept=True, lambda1=0.1, alpha=0.5,
@@ -479,6 +480,12 @@ def __init__(self, model='lm', penalization='lasso', quantile=0.5, fit_intercept
479480
self.group_weights = group_weights
480481
self.weight_tol = weight_tol
481482

483+
484+
485+
486+
487+
488+
482489
def _aridge(self, X, y, group_index):
483490
X, m, _ = self._prepare_data(X)
484491
beta_var = cvxpy.Variable(m)
@@ -582,16 +589,17 @@ def _asgl(self, X, y, group_index):
582589

583590
def fit(self, X, y, group_index=None, sample_weight=None):
584591
X, y = check_X_y(X, y)
585-
if self.penalization in self.grouped_penalizations and group_index is None:
592+
self.n_features_in_ = X.shape[1]
593+
if self.penalization in grouped_penalizations and group_index is None:
586594
raise ValueError(f'The penalization provided requires fitting the model with a group_index parameter but '
587595
f'no group_index was detected.')
588596
if self.penalization is None:
589597
beta_sol = self._unpenalized(X=X, y=y)
590598
self._split_beta_sol(beta_sol)
591-
elif self.penalization in self.non_adaptive_penalizations:
599+
elif self.penalization in non_adaptive_penalizations:
592600
beta_sol = getattr(self, '_' + self.penalization)(X=X, y=y, group_index=group_index)
593601
self._split_beta_sol(beta_sol)
594-
elif self.penalization in self.adaptive_penalizations:
602+
elif self.penalization in adaptive_penalizations:
595603
self.fit_weights(X=X, y=y, group_index=group_index)
596604
beta_sol = getattr(self, '_' + self.penalization)(X=X, y=y, group_index=group_index)
597605
self._split_beta_sol(beta_sol)

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99

1010
setup(
1111
name='asgl',
12-
version='2.1.1',
12+
version='2.1.2',
1313
author='Alvaro Mendez Civieta',
1414
author_email='[email protected]',
1515
license='GNU General Public License',
1616
zip_safe=False,
1717
url='https://github.com/alvaromc317/asgl',
18-
dowload_url='https://github.com/alvaromc317/asgl/archive/refs/tags/2.1.1.tar.gz',
18+
dowload_url='https://github.com/alvaromc317/asgl/archive/refs/tags/2.1.2.tar.gz',
1919
description='A regression solver for high dimensional penalized linear, quantile and logistic regression models',
2020
long_description=long_description,
2121
long_description_content_type='text/markdown',

0 commit comments

Comments
 (0)