Expected behavior

`cross_val_predict` should accept `OptunaSearchCV` as its estimator, but with scikit-learn >= 1.4.0 the call is rejected by the parameter validation performed by `validate_params`:

https://github.com/scikit-learn/scikit-learn/blob/46b5f541138458803e39f9ce5810878849e4ecf7/sklearn/model_selection/_validation.py#L1035-L1059
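The rejection can be illustrated with a stand-alone sketch. It assumes that `cross_val_predict` declares its `estimator` constraint as `HasMethods(["fit", "predict"])` (consistent with the message "an object implementing 'fit' and 'predict'"), that this constraint tests `callable(getattr(estimator, name, None))` for each name, and that `OptunaSearchCV.predict` is a property that raises an `AttributeError` subclass until the search has been fitted. `FakeNotFittedError`, `has_methods`, `_ZeroModel`, and `PropertySearch` are hypothetical stand-ins, not scikit-learn or Optuna code.

```python
class FakeNotFittedError(ValueError, AttributeError):
    """Hypothetical stand-in for a "not fitted" error that, like
    scikit-learn's NotFittedError, also subclasses AttributeError."""


def has_methods(obj, methods):
    """Hypothetical mimic of a HasMethods-style constraint check."""
    # getattr(obj, name, None) returns None whenever the lookup raises
    # AttributeError (or a subclass), so a property that raises on an
    # unfitted object makes the method look missing.
    return all(callable(getattr(obj, name, None)) for name in methods)


class _ZeroModel:
    """Toy fitted estimator that predicts 0.0 for every sample."""

    def predict(self, X):
        return [0.0 for _ in X]


class PropertySearch:
    """Toy search object whose `predict` is a property guarded by a fitted check."""

    def fit(self, X, y):
        # Pretend the search selected _ZeroModel as the best estimator.
        self.best_estimator_ = _ZeroModel()
        return self

    @property
    def predict(self):
        # The fitted check runs as soon as the attribute is looked up.
        if not hasattr(self, "best_estimator_"):
            raise FakeNotFittedError("This search object is not fitted yet.")
        return self.best_estimator_.predict


print(has_methods(PropertySearch(), ["fit", "predict"]))  # False
fitted = PropertySearch().fit([[1.0], [2.0]], [1.0, 2.0])
print(has_methods(fitted, ["fit", "predict"]))  # True
```

If those assumptions hold, an unfitted `OptunaSearchCV` looks as if it has no `predict` at validation time, even though `cross_val_predict` only calls `predict` on clones that it has already fitted.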
Error messages, stack traces, or logs

```
---------------------------------------------------------------------------
InvalidParameterError                     Traceback (most recent call last)
Cell In[1], line 15
      6 X, y = make_regression(n_samples=100, n_features=10, bias=1, random_state=334)
      8 ocv = optuna.integration.OptunaSearchCV(
      9     PLSRegression(),
     10     param_distributions=dict(
   (...)
     13     cv=5,
     14 )
---> 15 y_oof = cross_val_predict(ocv, X, y, cv=5)

File ~/miniforge3/envs/py311/lib/python3.11/site-packages/sklearn/utils/_param_validation.py:203, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    200 to_ignore += ["self", "cls"]
    201 params = {k: v for k, v in params.arguments.items() if k not in to_ignore}
--> 203 validate_parameter_constraints(
    204     parameter_constraints, params, caller_name=func.__qualname__
    205 )
    207 try:
    208     with config_context(
    209         skip_parameter_validation=(
    210             prefer_skip_nested_validation or global_skip_validation
    211         )
    212     ):

File ~/miniforge3/envs/py311/lib/python3.11/site-packages/sklearn/utils/_param_validation.py:95, in validate_parameter_constraints(parameter_constraints, params, caller_name)
     89 else:
     90     constraints_str = (
     91         f"{', '.join([str(c) for c in constraints[:-1]])} or"
     92         f" {constraints[-1]}"
     93     )
---> 95 raise InvalidParameterError(
     96     f"The {param_name!r} parameter of {caller_name} must be"
     97     f" {constraints_str}. Got {param_val!r} instead."
     98 )

InvalidParameterError: The 'estimator' parameter of cross_val_predict must be an object implementing 'fit' and 'predict'. Got OptunaSearchCV(cv=5, estimator=PLSRegression(), n_jobs=1, param_distributions={'n_components': IntDistribution(high=10, log=False, low=1, step=1)}) instead.
```
Steps to reproduce

```python
import optuna
from sklearn.cross_decomposition import PLSRegression
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_predict

X, y = make_regression(n_samples=100, n_features=10, bias=1, random_state=334)

ocv = optuna.integration.OptunaSearchCV(
    PLSRegression(),
    param_distributions=dict(
        n_components=optuna.distributions.IntDistribution(1, 10)
    ),
    cv=5,
)
y_oof = cross_val_predict(ocv, X, y, cv=5)
```
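A possible direction for a fix, sketched with toy classes rather than Optuna's actual code: define `predict` as an ordinary method that delegates to `best_estimator_` only when it is called. Assuming the `HasMethods`-style check behind the error above is `callable(getattr(obj, name, None))`, an unfitted object then passes validation, while calling `predict` on it still raises. `FakeNotFittedError`, `has_methods`, `_ZeroModel`, and `MethodSearch` are hypothetical names.

```python
class FakeNotFittedError(ValueError, AttributeError):
    """Hypothetical stand-in for a "not fitted" error."""


def has_methods(obj, methods):
    """Hypothetical mimic of a HasMethods-style constraint check."""
    return all(callable(getattr(obj, name, None)) for name in methods)


class _ZeroModel:
    """Toy fitted estimator that predicts 0.0 for every sample."""

    def predict(self, X):
        return [0.0 for _ in X]


class MethodSearch:
    """Toy search object whose `predict` is an ordinary method."""

    def fit(self, X, y):
        # Pretend the search selected _ZeroModel as the best estimator.
        self.best_estimator_ = _ZeroModel()
        return self

    def predict(self, X):
        # The fitted check runs when predict is called, not when it is looked up.
        if not hasattr(self, "best_estimator_"):
            raise FakeNotFittedError("This search object is not fitted yet.")
        return self.best_estimator_.predict(X)


unfitted = MethodSearch()
passes_check = has_methods(unfitted, ["fit", "predict"])
try:
    unfitted.predict([[1.0]])
    error_message = None
except FakeNotFittedError as exc:
    error_message = str(exc)

print(passes_check)  # True
print(error_message)  # This search object is not fitted yet.
```

The trade-off is that a plain method is present even when the wrapped estimator has no `predict`; scikit-learn's own search classes guard such delegating methods with `sklearn.utils.metaestimators.available_if`.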
Additional context (optional)

No response