Skip to content

Commit

Permalink
Fix custom residues in config (#229)
Browse files Browse the repository at this point in the history
* Fix specifying custom residues

* Update changelog
  • Loading branch information
bittremieux authored Aug 16, 2023
1 parent 3ac0887 commit 514db80
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Casanovo now runs on CPU and can passes all tests.
- Upgrade to Depthcharge v0.2.0 to fix sinusoidal encoding.
- Correctly refer to input peak files by their full file path.
- Specifying custom residues to retrain Casanovo is now possible.

## [3.3.0] - 2023-04-04

Expand Down
3 changes: 2 additions & 1 deletion casanovo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Config:
dropout=float,
dim_intensity=int,
max_length=int,
residues=dict,
n_log=int,
tb_summarywriter=str,
warmup_iters=int,
Expand All @@ -66,9 +67,9 @@ class Config:
save_top_k=int,
model_save_folder_path=str,
val_check_interval=int,
calculate_precision=bool,
accelerator=str,
devices=int,
calculate_precision=bool,
)

def __init__(self, config_file: Optional[str] = None):
Expand Down
16 changes: 13 additions & 3 deletions tests/unit_tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""Test configuration loading"""
import pytest

from casanovo.config import Config


Expand All @@ -17,11 +15,23 @@ def test_override(tmp_path):
"""Test overriding the default"""
yml = tmp_path / "test.yml"
with yml.open("w+") as f_out:
f_out.write("random_seed: 42\ntop_match: 3")
f_out.write(
"""random_seed: 42
top_match: 3
residues:
W: 1
O: 2
U: 3
T: 4
"""
)

config = Config(yml)
assert config.random_seed == 42
assert config["random_seed"] == 42
assert config.accelerator == "auto"
assert config.top_match == 3
assert len(config.residues) == 4
for i, residue in enumerate("WOUT", 1):
assert config["residues"][residue] == i
assert config.file == str(yml)

0 comments on commit 514db80

Please sign in to comment.