From 514db80ab5fc6963d8a15daf89acd2e4200d9cfa Mon Sep 17 00:00:00 2001 From: Wout Bittremieux Date: Wed, 16 Aug 2023 11:10:32 +0200 Subject: [PATCH] Fix custom residues in config (#229) * Fix specifying custom residues * Update changelog --- CHANGELOG.md | 1 + casanovo/config.py | 3 ++- tests/unit_tests/test_config.py | 16 +++++++++++++--- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a94acf9c..f600c7b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Casanovo now runs on CPU and can passes all tests. - Upgrade to Depthcharge v0.2.0 to fix sinusoidal encoding. - Correctly refer to input peak files by their full file path. +- Specifying custom residues to retrain Casanovo is now possible. ## [3.3.0] - 2023-04-04 diff --git a/casanovo/config.py b/casanovo/config.py index f2d24e3b..0274a1b1 100644 --- a/casanovo/config.py +++ b/casanovo/config.py @@ -50,6 +50,7 @@ class Config: dropout=float, dim_intensity=int, max_length=int, + residues=dict, n_log=int, tb_summarywriter=str, warmup_iters=int, @@ -66,9 +67,9 @@ class Config: save_top_k=int, model_save_folder_path=str, val_check_interval=int, + calculate_precision=bool, accelerator=str, devices=int, - calculate_precision=bool, ) def __init__(self, config_file: Optional[str] = None): diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 8da26f8c..1e2ef338 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -1,6 +1,4 @@ """Test configuration loading""" -import pytest - from casanovo.config import Config @@ -17,11 +15,23 @@ def test_override(tmp_path): """Test overriding the default""" yml = tmp_path / "test.yml" with yml.open("w+") as f_out: - f_out.write("random_seed: 42\ntop_match: 3") + f_out.write( + """random_seed: 42 +top_match: 3 +residues: + W: 1 + O: 2 + U: 3 + T: 4 +""" + ) config = Config(yml) assert config.random_seed == 42 assert config["random_seed"] == 42 assert config.accelerator == "auto" assert config.top_match == 3 + assert len(config.residues) == 4 + for i, residue in enumerate("WOUT", 1): + assert config["residues"][residue] == i assert config.file == str(yml)