diff --git a/models/generate_model.py b/models/generate_model.py index b6c76b5..02a8f9b 100644 --- a/models/generate_model.py +++ b/models/generate_model.py @@ -102,9 +102,7 @@ def create_model( ), None, ) - elif ( - model_name == "bd_linear_ncde" or "diagonal_linear_ncde" or "dense_linear_ncde" - ): + elif model_name in ["bd_linear_ncde", "diagonal_linear_ncde", "dense_linear_ncde"]: return ( LogLinearCDE( data_dim=data_dim,