Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade Optuna from v2.x.x to v3.0.0 #2360

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

__version__ = "1.3.0.dev0"
__version__ = "1.4.0.dev0"
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,8 @@
BaseDistribution,
CategoricalChoiceType,
CategoricalDistribution,
DiscreteUniformDistribution,
IntLogUniformDistribution,
IntUniformDistribution,
LogUniformDistribution,
UniformDistribution,
IntDistribution,
FloatDistribution,
)
from optuna.trial import Trial

Expand All @@ -62,17 +59,17 @@ def create_optuna_distribution_from_config(
assert param.low is not None
assert param.high is not None
if param.log:
return IntLogUniformDistribution(int(param.low), int(param.high))
return IntDistribution(int(param.low), int(param.high), log=True)
step = int(param.step) if param.step is not None else 1
return IntUniformDistribution(int(param.low), int(param.high), step=step)
return IntDistribution(int(param.low), int(param.high), step=step)
if param.type == DistributionType.float:
assert param.low is not None
assert param.high is not None
if param.log:
return LogUniformDistribution(param.low, param.high)
return FloatDistribution(param.low, param.high, log=True)
if param.step is not None:
return DiscreteUniformDistribution(param.low, param.high, param.step)
return UniformDistribution(param.low, param.high)
return FloatDistribution(param.low, param.high, step=param.step)
return FloatDistribution(param.low, param.high)
raise NotImplementedError(f"{param.type} is not supported by Optuna sweeper.")


Expand Down Expand Up @@ -107,23 +104,21 @@ def create_optuna_distribution_from_override(override: Override) -> Any:
or isinstance(value.stop, float)
or isinstance(value.step, float)
):
return DiscreteUniformDistribution(value.start, value.stop, value.step)
return IntUniformDistribution(
int(value.start), int(value.stop), step=int(value.step)
)
return FloatDistribution(value.start, value.stop, step=value.step)
return IntDistribution(int(value.start), int(value.stop), step=int(value.step))

if override.is_interval_sweep():
assert isinstance(value, IntervalSweep)
assert value.start is not None
assert value.end is not None
if "log" in value.tags:
if isinstance(value.start, int) and isinstance(value.end, int):
return IntLogUniformDistribution(int(value.start), int(value.end))
return LogUniformDistribution(value.start, value.end)
return IntDistribution(int(value.start), int(value.end), log=True)
return FloatDistribution(value.start, value.end, log=True)
else:
if isinstance(value.start, int) and isinstance(value.end, int):
return IntUniformDistribution(value.start, value.end)
return UniformDistribution(value.start, value.end)
return IntDistribution(value.start, value.end)
return FloatDistribution(value.start, value.end)

raise NotImplementedError(f"{override} is not supported by Optuna sweeper.")

Expand Down Expand Up @@ -266,13 +261,13 @@ def _parse_sweeper_params_config(self) -> List[str]:
def _to_grid_sampler_choices(self, distribution: BaseDistribution) -> Any:
if isinstance(distribution, CategoricalDistribution):
return distribution.choices
elif isinstance(distribution, IntUniformDistribution):
elif isinstance(distribution, IntDistribution):
assert (
distribution.step is not None
), "`step` of IntUniformDistribution must be a positive integer."
), "`step` of IntDistribution must be a positive integer."
n_items = (distribution.high - distribution.low) // distribution.step
return [distribution.low + i * distribution.step for i in range(n_items)]
elif isinstance(distribution, DiscreteUniformDistribution):
elif isinstance(distribution, FloatDistribution):
n_items = int((distribution.high - distribution.low) // distribution.q)
return [distribution.low + i * distribution.q for i in range(n_items)]
else:
Expand Down
3 changes: 1 addition & 2 deletions plugins/hydra_optuna_sweeper/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
],
install_requires=[
"hydra-core>=1.1.0.dev7",
"optuna>=2.10.0,<3.0.0",
"sqlalchemy~=1.3.0", # TODO: Unpin when upgrading to optuna v3.0
"optuna>=3.0.0",
],
include_package_data=True,
)
35 changes: 16 additions & 19 deletions plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,8 @@
from optuna.distributions import (
BaseDistribution,
CategoricalDistribution,
DiscreteUniformDistribution,
IntLogUniformDistribution,
IntUniformDistribution,
LogUniformDistribution,
UniformDistribution,
IntDistribution,
FloatDistribution,
)
from optuna.samplers import RandomSampler
from pytest import mark, warns
Expand Down Expand Up @@ -59,24 +56,24 @@ def check_distribution(expected: BaseDistribution, actual: BaseDistribution) ->
{"type": "categorical", "choices": [1, 2, 3]},
CategoricalDistribution([1, 2, 3]),
),
({"type": "int", "low": 0, "high": 10}, IntUniformDistribution(0, 10)),
({"type": "int", "low": 0, "high": 10}, IntDistribution(0, 10)),
(
{"type": "int", "low": 0, "high": 10, "step": 2},
IntUniformDistribution(0, 10, step=2),
IntDistribution(0, 10, step=2),
),
({"type": "int", "low": 0, "high": 5}, IntUniformDistribution(0, 5)),
({"type": "int", "low": 0, "high": 5}, IntDistribution(0, 5)),
(
{"type": "int", "low": 1, "high": 100, "log": True},
IntLogUniformDistribution(1, 100),
IntDistribution(1, 100, log=True),
),
({"type": "float", "low": 0, "high": 1}, UniformDistribution(0, 1)),
({"type": "float", "low": 0, "high": 1}, FloatDistribution(0, 1)),
(
{"type": "float", "low": 0, "high": 10, "step": 2},
DiscreteUniformDistribution(0, 10, 2),
FloatDistribution(0, 10, step=2),
),
(
{"type": "float", "low": 1, "high": 100, "log": True},
LogUniformDistribution(1, 100),
FloatDistribution(1, 100, log=True),
),
],
)
Expand All @@ -92,12 +89,12 @@ def test_create_optuna_distribution_from_config(input: Any, expected: Any) -> No
("key=choice(true, false)", CategoricalDistribution([True, False])),
("key=choice('hello', 'world')", CategoricalDistribution(["hello", "world"])),
("key=shuffle(range(1,3))", CategoricalDistribution((1, 2))),
("key=range(1,3)", IntUniformDistribution(1, 3)),
("key=interval(1, 5)", UniformDistribution(1, 5)),
("key=int(interval(1, 5))", IntUniformDistribution(1, 5)),
("key=tag(log, interval(1, 5))", LogUniformDistribution(1, 5)),
("key=tag(log, int(interval(1, 5)))", IntLogUniformDistribution(1, 5)),
("key=range(0.5, 5.5, step=1)", DiscreteUniformDistribution(0.5, 5.5, 1)),
("key=range(1,3)", IntDistribution(1, 3)),
("key=interval(1, 5)", FloatDistribution(1, 5)),
("key=int(interval(1, 5))", IntDistribution(1, 5)),
("key=tag(log, interval(1, 5))", FloatDistribution(1, 5, log=True)),
("key=tag(log, int(interval(1, 5)))", IntDistribution(1, 5, log=True)),
("key=range(0.5, 5.5, step=1)", FloatDistribution(0.5, 5.5, step=1)),
],
)
def test_create_optuna_distribution_from_override(input: Any, expected: Any) -> None:
Expand All @@ -121,7 +118,7 @@ def test_create_optuna_distribution_from_override(input: Any, expected: Any) ->
(
{
"key1": CategoricalDistribution([1, 2]),
"key3": IntUniformDistribution(1, 3),
"key3": IntDistribution(1, 3),
},
{"key2": "5"},
),
Expand Down
8 changes: 4 additions & 4 deletions website/docs/plugins/optuna_sweeper.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ Hydra provides a override parser that support rich syntax. Please refer to [Over

#### Interval override

By default, `interval` is converted to [`UniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.UniformDistribution.html). You can use [`IntUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntUniformDistribution.html), [`LogUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.LogUniformDistribution.html) or [`IntLogUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntLogUniformDistribution.html) by casting the interval to `int` and tagging it with `log`.
By default, `interval` is converted to [`FloatDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.FloatDistribution.html). You can use [`IntDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntDistribution.html) by casting the interval to `int`.

<details><summary>Example for interval override</summary>

Expand Down Expand Up @@ -147,8 +147,8 @@ The output is as follows:

#### Range override

`range` is converted to [`IntUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntUniformDistribution.html). If you apply `shuffle` to `range`, [`CategoricalDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.CategoricalDistribution.html) is used instead.
If any of `range`'s start, stop or step is of type float, it will be converted to [`DiscreteUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.DiscreteUniformDistribution.html)
`range` is converted to [`IntDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntDistribution.html). If you apply `shuffle` to `range`, [`CategoricalDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.CategoricalDistribution.html) is used instead.
If any of `range`'s start, stop or step is of type float, it will be converted to [`FloatDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.FloatDistribution.html)

<details><summary>Example for range override</summary>

Expand Down Expand Up @@ -321,4 +321,4 @@ Configuring a trial object is done in the following sequence:
- Command line overrides are set
- `custom_search_space` parameters are set

It is not allowed to set search space parameters in the `custom_search_space` method for parameters which have a fixed value from command line overrides. [Trial.user_attrs](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.user_attrs) can be inspected to find any of such fixed parameters.
It is not allowed to set search space parameters in the `custom_search_space` method for parameters which have a fixed value from command line overrides. [Trial.user_attrs](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.user_attrs) can be inspected to find any of such fixed parameters.
Loading