From 74c2c43e978c40c2288c436d7479c5edc7f9cc05 Mon Sep 17 00:00:00 2001 From: Michael Walsh Date: Mon, 25 Nov 2024 14:51:32 -0500 Subject: [PATCH] fix versioning integration checks --- .../request/DiffractionCalibrationRequest.py | 4 +- .../request/LoadCalibrationRecordRequest.py | 9 +++++ src/snapred/backend/data/GroceryService.py | 4 +- .../backend/service/CalibrationService.py | 14 ++++++- tests/integration/test_versions_in_order.py | 40 +++++++++++-------- tests/util_tests/test_state_helpers.py | 4 +- 6 files changed, 50 insertions(+), 25 deletions(-) create mode 100644 src/snapred/backend/dao/request/LoadCalibrationRecordRequest.py diff --git a/src/snapred/backend/dao/request/DiffractionCalibrationRequest.py b/src/snapred/backend/dao/request/DiffractionCalibrationRequest.py index d773f13da..17b0ce27a 100644 --- a/src/snapred/backend/dao/request/DiffractionCalibrationRequest.py +++ b/src/snapred/backend/dao/request/DiffractionCalibrationRequest.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, field_validator -from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT +from snapred.backend.dao.indexing.Versioning import Version, VersionState from snapred.backend.dao.Limit import Pair from snapred.backend.dao.state.FocusGroup import FocusGroup from snapred.backend.error.ContinueWarning import ContinueWarning @@ -40,7 +40,7 @@ class DiffractionCalibrationRequest(BaseModel, extra="forbid"): continueFlags: Optional[ContinueWarning.Type] = ContinueWarning.Type.UNSET - startingTableVersion: int = VERSION_DEFAULT + startingTableVersion: Version = VersionState.DEFAULT @field_validator("fwhmMultipliers", mode="before") @classmethod diff --git a/src/snapred/backend/dao/request/LoadCalibrationRecordRequest.py b/src/snapred/backend/dao/request/LoadCalibrationRecordRequest.py new file mode 100644 index 000000000..ab7f0c180 --- /dev/null +++ b/src/snapred/backend/dao/request/LoadCalibrationRecordRequest.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + +from snapred.backend.dao import RunConfig +from snapred.backend.dao.indexing.Versioning import Version + + +class LoadCalibrationRecordRequest(BaseModel): + runConfig: RunConfig + version: Version diff --git a/src/snapred/backend/data/GroceryService.py b/src/snapred/backend/data/GroceryService.py index 2046eceb6..d279c825b 100644 --- a/src/snapred/backend/data/GroceryService.py +++ b/src/snapred/backend/data/GroceryService.py @@ -14,7 +14,7 @@ ) from pydantic import validate_call -from snapred.backend.dao.indexing.Versioning import VERSION_START, VersionState +from snapred.backend.dao.indexing.Versioning import VERSION_START, Version, VersionState from snapred.backend.dao.ingredients import GroceryListItem from snapred.backend.dao.state import DetectorState from snapred.backend.dao.WorkspaceMetadata import WorkspaceMetadata @@ -333,7 +333,7 @@ def createDiffcalTableWorkspaceName( self, runNumber: str, useLiteMode: bool, # noqa: ARG002 - version: Optional[int], + version: Optional[Version], ) -> WorkspaceName: """ NOTE: This method will IGNORE runNumber if the provided version is VersionState.DEFAULT diff --git a/src/snapred/backend/service/CalibrationService.py b/src/snapred/backend/service/CalibrationService.py index db113e383..b52be079d 100644 --- a/src/snapred/backend/service/CalibrationService.py +++ b/src/snapred/backend/service/CalibrationService.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import pydantic @@ -9,6 +9,7 @@ FocusGroupMetric, ) from snapred.backend.dao.indexing import IndexEntry +from snapred.backend.dao.indexing.Versioning import VERSION_START, VersionState from snapred.backend.dao.ingredients import ( CalculateDiffCalResidualIngredients, CalibrationMetricsWorkspaceIngredients, @@ -29,6 +30,7 @@ FocusSpectraRequest, HasStateRequest, InitializeStateRequest, + LoadCalibrationRecordRequest, MatchRunsRequest, SimpleDiffCalRequest, ) @@ -134,6 +136,12 @@ def prepDiffractionCalibrationIngredients( @FromString def fetchDiffractionCalibrationGroceries(self, request: DiffractionCalibrationRequest) -> Dict[str, str]: # groceries + + # TODO: It would be nice for groceryclerk to be smart enough to flatten versions + # However I will save that scope for another time + if request.startingTableVersion == VersionState.DEFAULT: + request.startingTableVersion = VERSION_START + self.groceryClerk.name("inputWorkspace").neutron(request.runNumber).useLiteMode(request.useLiteMode).add() self.groceryClerk.name("groupingWorkspace").fromRun(request.runNumber).grouping( request.focusGroup.name @@ -360,10 +368,12 @@ def save(self, request: CalibrationExportRequest): self.dataExportService.exportCalibrationWorkspaces(record) @FromString - def load(self, run: RunConfig, version: Optional[int] = None): + def load(self, request: LoadCalibrationRecordRequest): """ If no version is given, will load the latest version applicable to the run number """ + run = request.runConfig + version = request.version return self.dataFactoryService.getCalibrationRecord(run.runNumber, run.useLiteMode, version) def matchRunsToCalibrationVersions(self, request: MatchRunsRequest) -> Dict[str, Any]: diff --git a/tests/integration/test_versions_in_order.py b/tests/integration/test_versions_in_order.py index 1a75b514f..53633127c 100644 --- a/tests/integration/test_versions_in_order.py +++ b/tests/integration/test_versions_in_order.py @@ -33,7 +33,7 @@ LoadEmptyInstrument, mtd, ) -from snapred.backend.dao.calibration.CalibrationRecord import CalibrationRecord +from snapred.backend.dao.calibration.CalibrationRecord import CalibrationDefaultRecord, CalibrationRecord from snapred.backend.dao.indexing.IndexEntry import IndexEntry from snapred.backend.dao.indexing.Versioning import VERSION_START, VersionState from snapred.backend.dao.SNAPRequest import SNAPRequest @@ -111,7 +111,7 @@ def _writeDefaultDiffCalTable(self, runNumber: str, useLiteMode: bool): """ Note this replicates the original in every respect, except using the ImitationGroceryService """ - version = VersionState.DEFAULT + version = VERSION_START grocer = ImitationGroceryService() outWS = grocer.fetchDefaultDiffCalTable(runNumber, useLiteMode, version) filename = Path(outWS + ".h5") @@ -295,7 +295,7 @@ def test_calibration_versioning(self): # ensure the new state has grouping map, calibration state, and default diffcal table diffCalTableName = wng.diffCalTable().runNumber("default").version(VersionState.DEFAULT).build() assert self.localDataService._groupingMapPath(self.stateId).exists() - versionDir = wnvf.pathVersion(VersionState.DEFAULT) + versionDir = wnvf.pathVersion(VERSION_START) assert Path(self.stateRoot, "lite", "diffraction", versionDir, "CalibrationParameters.json").exists() assert Path(self.stateRoot, "native", "diffraction", versionDir, "CalibrationParameters.json").exists() assert Path(self.stateRoot, "lite", "diffraction", versionDir, diffCalTableName + ".h5").exists() @@ -305,22 +305,22 @@ def test_calibration_versioning(self): assert [] == self.get_index() # assert the current diffcal version is the default, and the next is the start - assert self.indexer.currentVersion() == VersionState.DEFAULT - assert self.indexer.latestApplicableVersion(self.runNumber) == VersionState.DEFAULT - assert self.indexer.nextVersion() == VERSION_START + assert self.indexer.currentVersion() == VERSION_START + assert self.indexer.latestApplicableVersion(self.runNumber) == VERSION_START + assert self.indexer.nextVersion() == VERSION_START + 1 # run diffraction calibration for the first time, and save res = self.run_diffcal() - self.save_diffcal(res, version=None) + self.save_diffcal(res, version=VersionState.NEXT) # ensure things saved correctly - self.assert_diffcal_saved(VERSION_START) + self.assert_diffcal_saved(VERSION_START + 1) assert len(self.get_index()) == 1 # run diffraction calibration for a second time, and save res = self.run_diffcal() - self.save_diffcal(res, version=None) - self.assert_diffcal_saved(VERSION_START + 1) + self.save_diffcal(res, version=VersionState.NEXT) + self.assert_diffcal_saved(VERSION_START + 2) assert len(self.get_index()) == 2 # now save at version 7 @@ -333,7 +333,7 @@ def test_calibration_versioning(self): # now save at next version -- will be 8 version = 8 res = self.run_diffcal() - self.save_diffcal(res, version=None) # NOTE using None points it to next version + self.save_diffcal(res, version=VersionState.NEXT) self.assert_diffcal_saved(version) assert len(self.get_index()) == 4 @@ -375,7 +375,7 @@ def run_diffcal(self): assert response.code <= ResponseCode.MAX_OK return response.data - def save_diffcal(self, res, version=None): + def save_diffcal(self, res, version=VersionState.NEXT): # send a request through interface controller to save the diffcal results # needs the list of output workspaces, and may take an optional version # create an export request using an existing record as a basis @@ -409,7 +409,7 @@ def save_diffcal(self, res, version=None): "createIndexEntryRequest": createIndexEntryRequest, "createRecordRequest": createRecordRequest, } - request = SNAPRequest(path="calibration/save", payload=json.dumps(payload)) + request = SNAPRequest(path="calibration/save", payload=json.dumps(payload, default=str)) response = self.api.executeRequest(request) assert response.code <= ResponseCode.MAX_OK return response.data @@ -419,13 +419,19 @@ def assert_diffcal_saved(self, version): assert self.indexer.versionPath(version).exists() assert self.indexer.recordPath(version).exists() assert self.indexer.parametersPath(version).exists() - savedRecord = parse_file_as(CalibrationRecord, self.indexer.recordPath(version)) + savedRecord = None + if version == VERSION_START: + savedRecord = parse_file_as(CalibrationDefaultRecord, self.indexer.recordPath(version)) + else: + savedRecord = parse_file_as(CalibrationRecord, self.indexer.recordPath(version)) + assert savedRecord.version == version assert savedRecord.calculationParameters.version == version # make sure all workspaces exist workspaces = savedRecord.workspaces - assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_OUTPUT][0] + ".nxs.h5")).exists() - assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_DIAG][0] + ".nxs.h5")).exists() + if not version == VERSION_START: + assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_OUTPUT][0] + ".nxs.h5")).exists() + assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_DIAG][0] + ".nxs.h5")).exists() assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_TABLE][0] + ".h5")).exists() # assert this version is in the index index = self.indexer.readIndex() @@ -437,7 +443,7 @@ def assert_diffcal_saved(self, version): assert self.indexer.latestApplicableVersion(self.runNumber) == version assert self.indexer.nextVersion() == version + 1 # load the previous calibration and verify equality - runConfig = {"runNumber": self.runNumber, "useLiteMode": self.useLiteMode} + runConfig = {"runConfig": {"runNumber": self.runNumber, "useLiteMode": self.useLiteMode}, "version": version} request = SNAPRequest(path="calibration/load", payload=json.dumps(runConfig)) response = self.api.executeRequest(request) assert response.code <= ResponseCode.MAX_OK diff --git a/tests/util_tests/test_state_helpers.py b/tests/util_tests/test_state_helpers.py index c99932d00..6aed3315e 100644 --- a/tests/util_tests/test_state_helpers.py +++ b/tests/util_tests/test_state_helpers.py @@ -5,7 +5,7 @@ from shutil import rmtree import pytest -from snapred.backend.dao.indexing.Versioning import VersionState +from snapred.backend.dao.indexing.Versioning import VERSION_START from snapred.backend.data.LocalDataService import LocalDataService from snapred.meta.Config import Config from snapred.meta.mantid.WorkspaceNameGenerator import ValueFormatter as wnvf @@ -67,7 +67,7 @@ def test_state_root_override_enter( assert Path(stateRootPath) == expectedStateRootPath assert Path(stateRootPath).exists() assert Path(stateRootPath).joinpath("groupingMap.json").exists() - versionString = wnvf.pathVersion(VersionState.DEFAULT) + versionString = wnvf.pathVersion(VERSION_START) assert (Path(stateRootPath) / "lite" / "diffraction" / versionString / "CalibrationParameters.json").exists()