Skip to content

Commit

Permalink
Vesuvio rules (#374)
Browse files Browse the repository at this point in the history
* Add vesuvio rules and extracts

* Add tests for the Vesuvio rules

* Formatting and linting commit

---------

Co-authored-by: github-actions <[email protected]>
  • Loading branch information
Pasarus and github-actions authored Mar 7, 2025
1 parent 021cd8c commit df84dd8
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 1 deletion.
8 changes: 8 additions & 0 deletions rundetection/ingestion/extracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@ def mari_extract(job_request: JobRequest, dataset: Any) -> JobRequest:
return job_request


def vesuvio_extract(job_request: JobRequest, _: Any) -> JobRequest:
job_request.additional_values["runno"] = job_request.run_number

return job_request


def get_extraction_function(instrument: str) -> Callable[[JobRequest, Any], JobRequest]: # noqa: PLR0911
"""
Given an instrument name, return the additional metadata extraction function for the instrument
Expand All @@ -239,6 +245,8 @@ def get_extraction_function(instrument: str) -> Callable[[JobRequest, Any], JobR
return sans2d_extract
case "iris":
return iris_extract
case "vesuvio":
return vesuvio_extract
case _:
return skip_extract

Expand Down
7 changes: 7 additions & 0 deletions rundetection/rules/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
SansSliceWavs,
SansUserFile,
)
from rundetection.rules.vesuvio_rules import VesuvioEmptyRunsRule, VesuvioIPFileRule


def rule_factory(key_: str, value: T) -> Rule[Any]: # noqa: C901, PLR0911, PLR0912
Expand Down Expand Up @@ -86,6 +87,12 @@ def rule_factory(key_: str, value: T) -> Rule[Any]: # noqa: C901, PLR0911, PLR0
case "iriscalibration":
if isinstance(value, dict):
return IrisCalibrationRule(value)
case "vesuviovemptyrunsrule":
if isinstance(value, str):
return VesuvioEmptyRunsRule(value)
case "vesuvioipfilerule":
if isinstance(value, str):
return VesuvioIPFileRule(value)
case _:
raise MissingRuleError(f"Implementation of Rule: {key_} does not exist.")

Expand Down
28 changes: 28 additions & 0 deletions rundetection/rules/vesuvio_rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
Vesuvio Rules
"""

import logging

from rundetection.job_requests import JobRequest
from rundetection.rules.rule import Rule

logger = logging.getLogger(__name__)


class VesuvioEmptyRunsRule(Rule[str]):
"""
Adds the empty runs numbers to JobRequest
"""

def verify(self, job_request: JobRequest) -> None:
job_request.additional_values["empty_runs"] = self._value


class VesuvioIPFileRule(Rule[str]):
"""
Adds the ip_file to JobRequest
"""

def verify(self, job_request: JobRequest) -> None:
job_request.additional_values["ip_file"] = self._value
11 changes: 11 additions & 0 deletions test/ingestion/test_extracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
sans_extract,
skip_extract,
tosca_extract,
vesuvio_extract,
)
from rundetection.job_requests import JobRequest

Expand Down Expand Up @@ -65,6 +66,7 @@ def test_skip_extract(caplog: LogCaptureFixture):
("loq", "loq_extract"),
("sans2d", "sans2d_extract"),
("iris", "iris_extract"),
("vesuvio", "vesuvio_extract"),
],
)
def test_get_extraction_function(input_value, expected_function_name):
Expand Down Expand Up @@ -389,3 +391,12 @@ def test_get_cycle_string_from_path_invalid():
path = Path("/no/cycle/string/here")
with pytest.raises(IngestError):
get_cycle_string_from_path(path)


def test_vesuvio_extract_adds_runno(job_request):
"""
Tests that the extract adds runno to Vesuvio jobs
"""
result = vesuvio_extract(job_request, None)

assert result.additional_values["runno"] == 12345 # noqa: PLR2004
3 changes: 3 additions & 0 deletions test/rules/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
SansSliceWavs,
SansUserFile,
)
from rundetection.rules.vesuvio_rules import VesuvioEmptyRunsRule, VesuvioIPFileRule


def assert_correct_rule(name: str, value: Any, rule_type: type[Rule]):
Expand Down Expand Up @@ -64,6 +65,8 @@ def assert_correct_rule(name: str, value: Any, rule_type: type[Rule]):
("sansslicewavs", "[2.7, 3.7, 4.7, 5.7, 6.7, 8.7, 10.5]", SansSliceWavs),
("irisreduction", True, IrisReductionRule),
("iriscalibration", {"002": "00148587", "004": "00148587"}, IrisCalibrationRule),
("vesuvioipfilerule", "ip00001.par", VesuvioIPFileRule),
("vesuviovemptyrunsrule", "123-321", VesuvioEmptyRunsRule),
],
)
def test_rule_factory_returns_correct_rule(rule_key, rule_value, expected_rule):
Expand Down
65 changes: 65 additions & 0 deletions test/rules/test_vesuvio_rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Test for vesuvio rules
"""

import os
from pathlib import Path

import pytest

from rundetection.ingestion.ingest import JobRequest
from rundetection.rules.vesuvio_rules import VesuvioEmptyRunsRule, VesuvioIPFileRule


@pytest.fixture(autouse=True)
def _working_directory_fix():
# Set dir to repo root for purposes of the test.
current_working_directory = Path.cwd()
if current_working_directory.name == "rules":
os.chdir(current_working_directory / ".." / "..")


@pytest.fixture
def job_request():
"""
job request fixture
:return: job request
"""
return JobRequest(
run_number=100,
filepath=Path("/archive/100/VESUVIO100.nxs"),
experiment_title="Test experiment",
additional_values={},
additional_requests=[],
raw_frames=3,
good_frames=0,
users="",
run_start="",
run_end="",
instrument="vesuvio",
experiment_number="",
)


def test_vesuvio_empty_runs_rule(job_request):
"""
Test empty runs are set and attached to additional values
:param job_request: job request fixture
:return: none
"""
rule = VesuvioEmptyRunsRule("123-132")
rule.verify(job_request)

assert job_request.additional_values["empty_runs"] == "123-132"


def test_vesuvio_ip_file_rule(job_request):
"""
Test that the IP file is set via the specification
:param job_request: JobRequest fixture
:return: None
"""
rule = VesuvioIPFileRule("IP0001.par")
rule.verify(job_request)

assert job_request.additional_values["ip_file"] == "IP0001.par"
4 changes: 3 additions & 1 deletion test/test_data/specifications/vesuvio_specification.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{
"enabled": false
"enabled": true,
"vesuvioipfilerule": "IP0005.par",
"vesuviovemptyrunsrule": "50309-50341"
}

0 comments on commit df84dd8

Please sign in to comment.