Skip to content

Commit

Permalink
fix versioning integration checks
Browse files Browse the repository at this point in the history
  • Loading branch information
walshmm committed Nov 25, 2024
1 parent ba4ab2a commit 74c2c43
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel, field_validator

from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT
from snapred.backend.dao.indexing.Versioning import Version, VersionState
from snapred.backend.dao.Limit import Pair
from snapred.backend.dao.state.FocusGroup import FocusGroup
from snapred.backend.error.ContinueWarning import ContinueWarning
Expand Down Expand Up @@ -40,7 +40,7 @@ class DiffractionCalibrationRequest(BaseModel, extra="forbid"):

continueFlags: Optional[ContinueWarning.Type] = ContinueWarning.Type.UNSET

startingTableVersion: int = VERSION_DEFAULT
startingTableVersion: Version = VersionState.DEFAULT

@field_validator("fwhmMultipliers", mode="before")
@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pydantic import BaseModel

from snapred.backend.dao import RunConfig
from snapred.backend.dao.indexing.Versioning import Version


class LoadCalibrationRecordRequest(BaseModel):
runConfig: RunConfig
version: Version
4 changes: 2 additions & 2 deletions src/snapred/backend/data/GroceryService.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from pydantic import validate_call

from snapred.backend.dao.indexing.Versioning import VERSION_START, VersionState
from snapred.backend.dao.indexing.Versioning import VERSION_START, Version, VersionState
from snapred.backend.dao.ingredients import GroceryListItem
from snapred.backend.dao.state import DetectorState
from snapred.backend.dao.WorkspaceMetadata import WorkspaceMetadata
Expand Down Expand Up @@ -333,7 +333,7 @@ def createDiffcalTableWorkspaceName(
self,
runNumber: str,
useLiteMode: bool, # noqa: ARG002
version: Optional[int],
version: Optional[Version],
) -> WorkspaceName:
"""
NOTE: This method will IGNORE runNumber if the provided version is VersionState.DEFAULT
Expand Down
14 changes: 12 additions & 2 deletions src/snapred/backend/service/CalibrationService.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

import pydantic

Expand All @@ -9,6 +9,7 @@
FocusGroupMetric,
)
from snapred.backend.dao.indexing import IndexEntry
from snapred.backend.dao.indexing.Versioning import VERSION_START, VersionState
from snapred.backend.dao.ingredients import (
CalculateDiffCalResidualIngredients,
CalibrationMetricsWorkspaceIngredients,
Expand All @@ -29,6 +30,7 @@
FocusSpectraRequest,
HasStateRequest,
InitializeStateRequest,
LoadCalibrationRecordRequest,
MatchRunsRequest,
SimpleDiffCalRequest,
)
Expand Down Expand Up @@ -134,6 +136,12 @@ def prepDiffractionCalibrationIngredients(
@FromString
def fetchDiffractionCalibrationGroceries(self, request: DiffractionCalibrationRequest) -> Dict[str, str]:
# groceries

# TODO: It would be nice for groceryclerk to be smart enough to flatten versions
# However I will save that scope for another time
if request.startingTableVersion == VersionState.DEFAULT:
request.startingTableVersion = VERSION_START

self.groceryClerk.name("inputWorkspace").neutron(request.runNumber).useLiteMode(request.useLiteMode).add()
self.groceryClerk.name("groupingWorkspace").fromRun(request.runNumber).grouping(
request.focusGroup.name
Expand Down Expand Up @@ -360,10 +368,12 @@ def save(self, request: CalibrationExportRequest):
self.dataExportService.exportCalibrationWorkspaces(record)

@FromString
def load(self, run: RunConfig, version: Optional[int] = None):
def load(self, request: LoadCalibrationRecordRequest):
"""
If no version is given, will load the latest version applicable to the run number
"""
run = request.runConfig
version = request.version
return self.dataFactoryService.getCalibrationRecord(run.runNumber, run.useLiteMode, version)

def matchRunsToCalibrationVersions(self, request: MatchRunsRequest) -> Dict[str, Any]:
Expand Down
40 changes: 23 additions & 17 deletions tests/integration/test_versions_in_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
LoadEmptyInstrument,
mtd,
)
from snapred.backend.dao.calibration.CalibrationRecord import CalibrationRecord
from snapred.backend.dao.calibration.CalibrationRecord import CalibrationDefaultRecord, CalibrationRecord
from snapred.backend.dao.indexing.IndexEntry import IndexEntry
from snapred.backend.dao.indexing.Versioning import VERSION_START, VersionState
from snapred.backend.dao.SNAPRequest import SNAPRequest
Expand Down Expand Up @@ -111,7 +111,7 @@ def _writeDefaultDiffCalTable(self, runNumber: str, useLiteMode: bool):
"""
Note this replicates the original in every respect, except using the ImitationGroceryService
"""
version = VersionState.DEFAULT
version = VERSION_START
grocer = ImitationGroceryService()
outWS = grocer.fetchDefaultDiffCalTable(runNumber, useLiteMode, version)
filename = Path(outWS + ".h5")
Expand Down Expand Up @@ -295,7 +295,7 @@ def test_calibration_versioning(self):
# ensure the new state has grouping map, calibration state, and default diffcal table
diffCalTableName = wng.diffCalTable().runNumber("default").version(VersionState.DEFAULT).build()
assert self.localDataService._groupingMapPath(self.stateId).exists()
versionDir = wnvf.pathVersion(VersionState.DEFAULT)
versionDir = wnvf.pathVersion(VERSION_START)
assert Path(self.stateRoot, "lite", "diffraction", versionDir, "CalibrationParameters.json").exists()
assert Path(self.stateRoot, "native", "diffraction", versionDir, "CalibrationParameters.json").exists()
assert Path(self.stateRoot, "lite", "diffraction", versionDir, diffCalTableName + ".h5").exists()
Expand All @@ -305,22 +305,22 @@ def test_calibration_versioning(self):
assert [] == self.get_index()

# assert the current diffcal version is the default, and the next is the start
assert self.indexer.currentVersion() == VersionState.DEFAULT
assert self.indexer.latestApplicableVersion(self.runNumber) == VersionState.DEFAULT
assert self.indexer.nextVersion() == VERSION_START
assert self.indexer.currentVersion() == VERSION_START
assert self.indexer.latestApplicableVersion(self.runNumber) == VERSION_START
assert self.indexer.nextVersion() == VERSION_START + 1

# run diffraction calibration for the first time, and save
res = self.run_diffcal()
self.save_diffcal(res, version=None)
self.save_diffcal(res, version=VersionState.NEXT)

# ensure things saved correctly
self.assert_diffcal_saved(VERSION_START)
self.assert_diffcal_saved(VERSION_START + 1)
assert len(self.get_index()) == 1

# run diffraction calibration for a second time, and save
res = self.run_diffcal()
self.save_diffcal(res, version=None)
self.assert_diffcal_saved(VERSION_START + 1)
self.save_diffcal(res, version=VersionState.NEXT)
self.assert_diffcal_saved(VERSION_START + 2)
assert len(self.get_index()) == 2

# now save at version 7
Expand All @@ -333,7 +333,7 @@ def test_calibration_versioning(self):
# now save at next version -- will be 8
version = 8
res = self.run_diffcal()
self.save_diffcal(res, version=None) # NOTE using None points it to next version
self.save_diffcal(res, version=VersionState.NEXT)
self.assert_diffcal_saved(version)
assert len(self.get_index()) == 4

Expand Down Expand Up @@ -375,7 +375,7 @@ def run_diffcal(self):
assert response.code <= ResponseCode.MAX_OK
return response.data

def save_diffcal(self, res, version=None):
def save_diffcal(self, res, version=VersionState.NEXT):
# send a request through interface controller to save the diffcal results
# needs the list of output workspaces, and may take an optional version
# create an export request using an existing record as a basis
Expand Down Expand Up @@ -409,7 +409,7 @@ def save_diffcal(self, res, version=None):
"createIndexEntryRequest": createIndexEntryRequest,
"createRecordRequest": createRecordRequest,
}
request = SNAPRequest(path="calibration/save", payload=json.dumps(payload))
request = SNAPRequest(path="calibration/save", payload=json.dumps(payload, default=str))
response = self.api.executeRequest(request)
assert response.code <= ResponseCode.MAX_OK
return response.data
Expand All @@ -419,13 +419,19 @@ def assert_diffcal_saved(self, version):
assert self.indexer.versionPath(version).exists()
assert self.indexer.recordPath(version).exists()
assert self.indexer.parametersPath(version).exists()
savedRecord = parse_file_as(CalibrationRecord, self.indexer.recordPath(version))
savedRecord = None
if version == VERSION_START:
savedRecord = parse_file_as(CalibrationDefaultRecord, self.indexer.recordPath(version))
else:
savedRecord = parse_file_as(CalibrationRecord, self.indexer.recordPath(version))

assert savedRecord.version == version
assert savedRecord.calculationParameters.version == version
# make sure all workspaces exist
workspaces = savedRecord.workspaces
assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_OUTPUT][0] + ".nxs.h5")).exists()
assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_DIAG][0] + ".nxs.h5")).exists()
if not version == VERSION_START:
assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_OUTPUT][0] + ".nxs.h5")).exists()
assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_DIAG][0] + ".nxs.h5")).exists()
assert (self.indexer.versionPath(version) / (workspaces[wngt.DIFFCAL_TABLE][0] + ".h5")).exists()
# assert this version is in the index
index = self.indexer.readIndex()
Expand All @@ -437,7 +443,7 @@ def assert_diffcal_saved(self, version):
assert self.indexer.latestApplicableVersion(self.runNumber) == version
assert self.indexer.nextVersion() == version + 1
# load the previous calibration and verify equality
runConfig = {"runNumber": self.runNumber, "useLiteMode": self.useLiteMode}
runConfig = {"runConfig": {"runNumber": self.runNumber, "useLiteMode": self.useLiteMode}, "version": version}
request = SNAPRequest(path="calibration/load", payload=json.dumps(runConfig))
response = self.api.executeRequest(request)
assert response.code <= ResponseCode.MAX_OK
Expand Down
4 changes: 2 additions & 2 deletions tests/util_tests/test_state_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from shutil import rmtree

import pytest
from snapred.backend.dao.indexing.Versioning import VersionState
from snapred.backend.dao.indexing.Versioning import VERSION_START
from snapred.backend.data.LocalDataService import LocalDataService
from snapred.meta.Config import Config
from snapred.meta.mantid.WorkspaceNameGenerator import ValueFormatter as wnvf
Expand Down Expand Up @@ -67,7 +67,7 @@ def test_state_root_override_enter(
assert Path(stateRootPath) == expectedStateRootPath
assert Path(stateRootPath).exists()
assert Path(stateRootPath).joinpath("groupingMap.json").exists()
versionString = wnvf.pathVersion(VersionState.DEFAULT)
versionString = wnvf.pathVersion(VERSION_START)
assert (Path(stateRootPath) / "lite" / "diffraction" / versionString / "CalibrationParameters.json").exists()


Expand Down

0 comments on commit 74c2c43

Please sign in to comment.