From 1f922e825d82b4f40e8e65bb7165a276852a7960 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 5 Jul 2023 14:25:22 -0400 Subject: [PATCH] Use registry.MODEL_FAMILY_FROM_NAME when determining net kwargs in vak.models.get --- src/vak/models/get.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/src/vak/models/get.py b/src/vak/models/get.py index 380edff9b..b9a137cf8 100644 --- a/src/vak/models/get.py +++ b/src/vak/models/get.py @@ -53,24 +53,29 @@ def get(name: str, f"Valid model names are: {registry.MODEL_NAMES}" ) from e - # still need to special case model logic here - net_init_params = list( - inspect.signature( - model_class.definition.network.__init__ - ).parameters.keys() - ) - if ('num_input_channels' in net_init_params) and ('num_freqbins' in net_init_params): - num_input_channels = input_shape[-3] - num_freqbins = input_shape[-2] - config["network"].update( - num_classes=num_classes, - num_input_channels=num_input_channels, - num_freqbins=num_freqbins - ) - else: - raise ValueError( - f"Unable to determine network init arguments for model: {name}" + model_family = registry.MODEL_FAMILY_FROM_NAME[name] + + if model_family == 'FrameClassificationModel': + # still need to special case model logic here + net_init_params = list( + inspect.signature( + model_class.definition.network.__init__ + ).parameters.keys() ) + if ('num_input_channels' in net_init_params) and ('num_freqbins' in net_init_params): + num_input_channels = input_shape[-3] + num_freqbins = input_shape[-2] + config["network"].update( + num_classes=num_classes, + num_input_channels=num_input_channels, + num_freqbins=num_freqbins + ) + else: + raise ValueError( + f"Detected that model with name '{name}' was family '{model_family}', but " + f"unable to determine network init arguments for model. Currently all models " + f"in this family must have networks with parameters ``num_input_channels`` and ``num_freqbins``" + ) model = model_class.from_config(config=config, labelmap=labelmap, post_tfm=post_tfm)