diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 4ff5134d..0d93d57a 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -92,6 +92,7 @@ def train(config: Config): num_workers=config.data.num_workers, fake_data=config.data.fake, ) + model, model_config = get_model( config.name_model, config.type_model,