Hi, I am using hydra_zen to configure a machine learning project. I want to pass the *type* of the activation function to the model, so that the model can instantiate it itself. The problem: I am unable to override the activation function from the `launch` function.

```python
from typing import Type

from torch import nn
from hydra_zen import to_yaml, store, builds, zen, launch
from omegaconf import DictConfig, OmegaConf


class Model:
    def __init__(self, activation_fn: Type[nn.Module]):
        # Store the class itself; the model instantiates it later.
        self.activation_fn = activation_fn


def app(zen_cfg: DictConfig, model: Model) -> None:
    OmegaConf.resolve(zen_cfg)
    print(to_yaml(zen_cfg))


# Stored in the "model" group under the name "Model".
store(Model, activation_fn=nn.PReLU, group='model')

Config = builds(
    app,
    hydra_defaults=[
        {'model': 'Model'},
        '_self_'
    ]
)

if __name__ == '__main__':
    store.add_to_hydra_store()
    launch(Config, zen(app), version_base='1.3')
```

This does exactly what I want: in the resulting `zen_cfg`, the activation function will actually be a type. However, when I launch with an override,

```python
launch(Config, zen(app), version_base='1.3', overrides=['model.activation_fn=nn.SELU'])
```

I get a `zen_cfg` in which the activation function is a string.

Is there a way to do this? Or might one of the following alternatives work: Changing the …

Thanks a lot!
I did a bit of experimenting, and the approach with config groups seems to work:

```python
from typing import Type

from torch import nn
from hydra_zen import to_yaml, store, builds, zen, launch, just
from omegaconf import DictConfig, OmegaConf


class Model:
    def __init__(self, activation_fn: Type[nn.Module]):
        self.activation_fn = activation_fn


def app(zen_cfg: DictConfig, model: Model) -> None:
    OmegaConf.resolve(zen_cfg)
    print(to_yaml(zen_cfg))


store(Model, group='model')

# One group entry per activation; just(...) instantiates to the class itself.
activation_fn_store = store(group='model/activation_fn')
activation_fn_store(just(nn.ReLU), name='relu')
activation_fn_store(just(nn.SELU), name='selu')

Config = builds(
    app,
    hydra_defaults=[
        {'model': 'Model'},
        {'model/activation_fn': 'relu'},
        '_self_'
    ]
)

if __name__ == '__main__':
    store.add_to_hydra_store()
    # Select a different group entry instead of overriding the value directly.
    launch(Config, zen(app), version_base='1.3', overrides=['model/activation_fn=selu'])
```

With the override `model/activation_fn=selu`, this selects the `selu` entry of the `model/activation_fn` group.

Please tell me there is a better way to do this! Or should I approach the problem completely differently?