Allow use of OptunaSearchCV with cross_val_predict. #173

Open
yu9824 opened this issue Oct 22, 2024 · 0 comments
Labels
bug Something isn't working

@yu9824
Contributor

yu9824 commented Oct 22, 2024

Expected behavior

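cross_val_predict should accept OptunaSearchCV as its estimator. With scikit-learn >= 1.4.0, however, the call is rejected by the validate_params check before any fitting happens: the 'estimator' parameter must be an object implementing 'fit' and 'predict', and the OptunaSearchCV instance does not pass that constraint.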

https://github.com/scikit-learn/scikit-learn/blob/46b5f541138458803e39f9ce5810878849e4ecf7/sklearn/model_selection/_validation.py#L1035-L1059
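A minimal sketch of one plausible mechanism, resting on two assumptions that this report does not state: that scikit-learn's HasMethods constraint is satisfied only when callable(getattr(obj, name, None)) holds for each required name, and that OptunaSearchCV exposes predict as a property that raises NotFittedError (which subclasses both ValueError and AttributeError) until the search has been fitted. Because getattr with a default swallows AttributeError subclasses, such a property would make an unfitted instance look as if it had no predict at all. The has_methods helper and the PropertyPredict, MethodPredict, and NotFittedLikeError classes are hypothetical stand-ins written with the standard library only.

```python
class NotFittedLikeError(ValueError, AttributeError):
    """Stand-in for sklearn.exceptions.NotFittedError (assumed to subclass
    both ValueError and AttributeError)."""


def has_methods(obj, methods):
    """Approximation (assumption) of the HasMethods(["fit", "predict"]) check:
    every name must resolve to a callable via getattr(obj, name, None)."""
    return all(callable(getattr(obj, name, None)) for name in methods)


class PropertyPredict:
    """Hypothetical estimator whose `predict` is a property that raises
    until fit() has run (assumed to mirror OptunaSearchCV)."""

    def __init__(self):
        self.best_estimator_ = None

    def fit(self, X, y):
        # Dummy "model" that predicts 0.0 for every row.
        self.best_estimator_ = lambda rows: [0.0 for _ in rows]
        return self

    @property
    def predict(self):
        if self.best_estimator_ is None:
            # getattr(obj, "predict", None) catches this and returns None.
            raise NotFittedLikeError("not fitted yet")
        return self.best_estimator_


class MethodPredict:
    """Hypothetical estimator whose `predict` is a plain method that only
    raises when it is actually called on an unfitted instance."""

    def __init__(self):
        self.best_estimator_ = None

    def fit(self, X, y):
        self.best_estimator_ = lambda rows: [0.0 for _ in rows]
        return self

    def predict(self, X):
        if self.best_estimator_ is None:
            raise NotFittedLikeError("not fitted yet")
        return self.best_estimator_(X)


if __name__ == "__main__":
    required = ["fit", "predict"]
    print("property, unfitted:", has_methods(PropertyPredict(), required))
    print("property, fitted:  ", has_methods(PropertyPredict().fit([[1.0]], [1.0]), required))
    print("method, unfitted:  ", has_methods(MethodPredict(), required))
```

If those assumptions hold, a predict defined as a regular method that raises only when called, as in MethodPredict, would pass the check on an unfitted instance.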

Environment

  • Optuna version: 3.5.0
  • Optuna Integration version: 3.5.0
  • Python version: 3.11.6
  • OS: macOS-14.7-arm64-arm-64bit
  • scikit-learn version: 1.4.0
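To tell whether a given environment falls in the range reported above (scikit-learn >= 1.4.0), a rough sketch using importlib.metadata. The 1.4.0 threshold comes from this report rather than from the scikit-learn changelog, pre-release and post-release suffixes are ignored, and release_tuple, is_reported_affected, and installed_sklearn_affected are hypothetical helper names.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# First scikit-learn release reported as affected in this issue
# (assumption: not cross-checked against the scikit-learn changelog).
REPORTED_FIRST_AFFECTED = (1, 4, 0)


def release_tuple(ver: str) -> Tuple[int, ...]:
    """Leading numeric release segment, e.g. '1.4.0rc1' -> (1, 4, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unparseable version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def is_reported_affected(ver: str) -> bool:
    """True if `ver` is >= the reported threshold (suffixes ignored)."""
    release = release_tuple(ver)
    # Pad so that '1.4' compares equal to (1, 4, 0) rather than below it.
    padded = release + (0,) * (3 - len(release))
    return padded[:3] >= REPORTED_FIRST_AFFECTED


def installed_sklearn_affected() -> Optional[bool]:
    """Check the installed scikit-learn; None if it is not installed."""
    try:
        return is_reported_affected(version("scikit-learn"))
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("scikit-learn in reported range:", installed_sklearn_affected())
```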

Error messages, stack traces, or logs

---------------------------------------------------------------------------
InvalidParameterError                     Traceback (most recent call last)
Cell In[1], line 15
      6 X, y = make_regression(n_samples=100, n_features=10, bias=1, random_state=334)
      8 ocv = optuna.integration.OptunaSearchCV(
      9     PLSRegression(),
     10     param_distributions=dict(
   (...)
     13     cv=5,
     14 )
---> 15 y_oof = cross_val_predict(ocv, X, y, cv=5)

File ~/miniforge3/envs/py311/lib/python3.11/site-packages/sklearn/utils/_param_validation.py:203, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    200 to_ignore += ["self", "cls"]
    201 params = {k: v for k, v in params.arguments.items() if k not in to_ignore}
--> 203 validate_parameter_constraints(
    204     parameter_constraints, params, caller_name=func.__qualname__
    205 )
    207 try:
    208     with config_context(
    209         skip_parameter_validation=(
    210             prefer_skip_nested_validation or global_skip_validation
    211         )
    212     ):

File ~/miniforge3/envs/py311/lib/python3.11/site-packages/sklearn/utils/_param_validation.py:95, in validate_parameter_constraints(parameter_constraints, params, caller_name)
     89 else:
     90     constraints_str = (
     91         f"{', '.join([str(c) for c in constraints[:-1]])} or"
     92         f" {constraints[-1]}"
     93     )
---> 95 raise InvalidParameterError(
     96     f"The {param_name!r} parameter of {caller_name} must be"
     97     f" {constraints_str}. Got {param_val!r} instead."
     98 )

InvalidParameterError: The 'estimator' parameter of cross_val_predict must be an object implementing 'fit' and 'predict'. Got OptunaSearchCV(cv=5, estimator=PLSRegression(), n_jobs=1,
               param_distributions={'n_components': IntDistribution(high=10, log=False, low=1, step=1)}) instead.

Steps to reproduce

import optuna
from sklearn.cross_decomposition import PLSRegression
from sklearn.datasets import make_regression
from sklearn.model_selection import cross_val_predict

X, y = make_regression(n_samples=100, n_features=10, bias=1, random_state=334)

ocv = optuna.integration.OptunaSearchCV(
    PLSRegression(),
    param_distributions=dict(
        n_components=optuna.distributions.IntDistribution(1, 10)
    ),
    cv=5,
)
y_oof = cross_val_predict(ocv, X, y, cv=5)  # raises InvalidParameterError with scikit-learn 1.4.0

Additional context (optional)

No response
