Skip to content

Commit

Permalink
requested changes, _setup_output unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilferrit committed Sep 4, 2024
1 parent 1ee28be commit 4cb18e1
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
10 changes: 6 additions & 4 deletions casanovo/casanovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def _setup_output(
be resolved to the current working directory.
output_root : str | None
The base name for the output files. If `None` the output root name will
be resolved to casanovo_<current data and time>
be resolved to casanovo_<current date and time>
overwrite: bool
Whether to overwrite log file if it already exists in the output
directory.
Expand All @@ -544,10 +544,12 @@ def _setup_output(
if output_dir is None:
output_path = Path.cwd()
else:
output_path = Path(output_dir)
output_path = Path(output_dir).expanduser().resolve()
if not output_path.is_dir():
raise FileNotFoundError(
f"Target output directory {output_dir} does not exists."
output_path.mkdir(parents=True)
warnings.warn(
f"Target output directory {output_dir} does not exists, "
"so it will be created."
)

if not overwrite:
Expand Down
2 changes: 1 addition & 1 deletion casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(

if output_dir is None:
self.callbacks = []
warnings.warn(
logger.warning(
"Checkpoint directory not set in ModelRunner, "
"no checkpoint files will be saved."
)
Expand Down
20 changes: 20 additions & 0 deletions tests/unit_tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import pathlib
import platform
import re
import requests
import shutil
import tempfile
Expand Down Expand Up @@ -944,3 +945,22 @@ def test_check_dir(tmp_path):

utils.check_dir_file_exists(tmp_path, [dne_pattern])
utils.check_dir_file_exists(tmp_path, dne_pattern)


def test_setup_output(tmp_path, monkeypatch):
with monkeypatch.context() as mnk:
mnk.setattr(pathlib.Path, "cwd", lambda: tmp_path)
output_path, output_root = casanovo._setup_output(
None, None, False, "info"
)
assert output_path.resolve() == tmp_path.resolve()
assert re.fullmatch(r"casanovo_\d+", output_root) is not None

target_path = tmp_path / "foo"
with pytest.warns(UserWarning):
output_path, output_root = casanovo._setup_output(
str(target_path), "bar", False, "info"
)

assert output_path.resolve() == target_path.resolve()
assert output_root == "bar"

0 comments on commit 4cb18e1

Please sign in to comment.