8
8
from sklearn .decomposition import SparsePCA
9
9
from sklearn .utils import check_X_y
10
10
11
+ non_adaptive_penalizations = ['lasso' , 'ridge' , 'gl' , 'sgl' ]
12
+ adaptive_penalizations = ['alasso' , 'aridge' , 'agl' , 'asgl' ]
13
+ grouped_penalizations = ['gl' , 'agl' , 'sgl' , 'asgl' ]
14
+
11
15
12
16
class BaseModel (BaseEstimator , RegressorMixin ):
13
17
def __init__ (self , model = 'lm' , penalization = 'lasso' , quantile = 0.5 , fit_intercept = True , lambda1 = 0.1 , alpha = 0.5 ,
@@ -20,12 +24,6 @@ def __init__(self, model='lm', penalization='lasso', quantile=0.5, fit_intercept
20
24
self .alpha = alpha
21
25
self .solver = solver
22
26
self .tol = tol
23
- self .coef_ = None
24
- self .intercept_ = None
25
- self .solver_stats = None
26
- self .non_adaptive_penalizations = ['lasso' , 'ridge' , 'gl' , 'sgl' ]
27
- self .adaptive_penalizations = ['alasso' , 'aridge' , 'agl' , 'asgl' ]
28
- self .grouped_penalizations = ['gl' , 'agl' , 'sgl' , 'asgl' ]
29
27
30
28
def _quantile_function (self , X ):
31
29
return 0.5 * cvxpy .abs (X ) + (self .quantile - 0.5 ) * X
@@ -182,13 +180,14 @@ def _sgl(self, X, y, group_index):
182
180
183
181
def fit (self , X , y , group_index = None , sample_weight = None ):
184
182
X , y = check_X_y (X , y )
185
- if self .penalization in self .grouped_penalizations and group_index is None :
183
+ self .n_features_in_ = X .shape [1 ]
184
+ if self .penalization in grouped_penalizations and group_index is None :
186
185
raise ValueError (
187
186
f'The penalization provided requires fitting the model with a group_index parameter but no group_index was detected.' )
188
187
if self .penalization is None :
189
188
beta_sol = self ._unpenalized (X = X , y = y )
190
189
self ._split_beta_sol (beta_sol )
191
- elif self .penalization in self . non_adaptive_penalizations :
190
+ elif self .penalization in non_adaptive_penalizations :
192
191
beta_sol = getattr (self , '_' + self .penalization )(X = X , y = y , group_index = group_index )
193
192
self ._split_beta_sol (beta_sol )
194
193
else :
@@ -453,6 +452,8 @@ class Regressor(BaseModel, AdaptiveWeights):
453
452
Estimated coefficients for the regression problem.
454
453
intercept_: float
455
454
Independent term in the regression model
455
+ n_features_in_: int
456
+ Number of features seen during fit.
456
457
"""
457
458
458
459
def __init__ (self , model = 'lm' , penalization = 'lasso' , quantile = 0.5 , fit_intercept = True , lambda1 = 0.1 , alpha = 0.5 ,
@@ -479,6 +480,12 @@ def __init__(self, model='lm', penalization='lasso', quantile=0.5, fit_intercept
479
480
self .group_weights = group_weights
480
481
self .weight_tol = weight_tol
481
482
483
+
484
+
485
+
486
+
487
+
488
+
482
489
def _aridge (self , X , y , group_index ):
483
490
X , m , _ = self ._prepare_data (X )
484
491
beta_var = cvxpy .Variable (m )
@@ -582,16 +589,17 @@ def _asgl(self, X, y, group_index):
582
589
583
590
def fit (self , X , y , group_index = None , sample_weight = None ):
584
591
X , y = check_X_y (X , y )
585
- if self .penalization in self .grouped_penalizations and group_index is None :
592
+ self .n_features_in_ = X .shape [1 ]
593
+ if self .penalization in grouped_penalizations and group_index is None :
586
594
raise ValueError (f'The penalization provided requires fitting the model with a group_index parameter but '
587
595
f'no group_index was detected.' )
588
596
if self .penalization is None :
589
597
beta_sol = self ._unpenalized (X = X , y = y )
590
598
self ._split_beta_sol (beta_sol )
591
- elif self .penalization in self . non_adaptive_penalizations :
599
+ elif self .penalization in non_adaptive_penalizations :
592
600
beta_sol = getattr (self , '_' + self .penalization )(X = X , y = y , group_index = group_index )
593
601
self ._split_beta_sol (beta_sol )
594
- elif self .penalization in self . adaptive_penalizations :
602
+ elif self .penalization in adaptive_penalizations :
595
603
self .fit_weights (X = X , y = y , group_index = group_index )
596
604
beta_sol = getattr (self , '_' + self .penalization )(X = X , y = y , group_index = group_index )
597
605
self ._split_beta_sol (beta_sol )
0 commit comments