Skip to content

Commit

Permalink
fix remaining unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
walshmm committed Nov 25, 2024
1 parent 7993b55 commit dcd2a8f
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 18 deletions.
20 changes: 13 additions & 7 deletions src/snapred/backend/data/Indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,18 +340,24 @@ def readRecord(self, version: int) -> Record:
return record

def _flattenVersion(self, version: Version):
flattenedVersion = None
if version == VersionState.DEFAULT:
return self.defaultVersion()
flattenedVersion = self.defaultVersion()
elif version == VersionState.LATEST:
return self.currentVersion()
flattenedVersion = self.currentVersion()
elif version == VersionState.NEXT:
return self.nextVersion()
flattenedVersion = self.nextVersion()
elif isinstance(version, int):
return version
flattenedVersion = version
else:
if version is None:
raise ValueError("Version must be specified," "likely no available versions found during lookup.")
raise ValueError(f"Version must be an int or {VersionState.values()}")
raise ValueError(f"Version must be an int or {VersionState.values()}, not {version}")

if flattenedVersion is None:
raise ValueError(
f"No available versions found during lookup using: "
f"v={version}, index={self.index}, dir={self.dirVersions}"
)
return flattenedVersion

def versionExists(self, version: Version):
return self._flattenVersion(version) in self.index
Expand Down
8 changes: 3 additions & 5 deletions src/snapred/backend/data/LocalDataService.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,14 +446,15 @@ def _statePathForWorkflow(self, stateId: str, useLiteMode: bool, indexerType: In
raise NotImplementedError(f"Indexer of type {indexerType} is not supported by the LocalDataService")
return path

# @lru_cache
@lru_cache
def _indexer(self, stateId: str, useLiteMode: bool, indexerType: IndexerType):
path = self._statePathForWorkflow(stateId, useLiteMode, indexerType)
return Indexer(indexerType=indexerType, directory=path)

def indexer(self, runNumber: str, useLiteMode: bool, indexerType: IndexerType):
stateId, _ = self.generateStateId(runNumber)
return self._indexer(stateId, useLiteMode, indexerType)
indexer = self._indexer(stateId, useLiteMode, indexerType)
return indexer

def calibrationIndexer(self, runId: str, useLiteMode: bool):
return self.indexer(runId, useLiteMode, IndexerType.CALIBRATION)
Expand Down Expand Up @@ -797,9 +798,6 @@ def readCalibrationState(self, runId: str, useLiteMode: bool, version: Optional[
@validate_call
def readNormalizationState(self, runId: str, useLiteMode: bool, version: Optional[Version] = VersionState.LATEST):
indexer = self.normalizationIndexer(runId, useLiteMode)
# NOTE if we prefer latest version in index, uncomment below
if version is VersionState.LATEST:
version = indexer.latestApplicableVersion(runId)
return indexer.readParameters(version)

def writeCalibrationState(self, calibration: Calibration):
Expand Down
18 changes: 12 additions & 6 deletions tests/unit/backend/data/test_LocalDataService.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,10 @@ def do_test_read_state_no_version(workflow: Literal["Calibration", "Normalizatio
indexer.index = {
currentVersion: mock.MagicMock(appliesTo="123", version=currentVersion)
} # NOTE manually update indexer
actualState = getattr(localDataService, f"read{workflow}State")("123", useLiteMode) # NOTE no version
indexer.dirVersions = [currentVersion] # NOTE manually update indexer
actualState = getattr(localDataService, f"read{workflow}State")(
"123", useLiteMode, VersionState.LATEST
) # NOTE no version
assert actualState == expectedState


Expand Down Expand Up @@ -1175,7 +1178,8 @@ def test_createCalibrationIndexEntry():
request.version = VersionState.NEXT
localDataService.calibrationIndexer(request.runNumber, request.useLiteMode)
ans = localDataService.createCalibrationIndexEntry(request)
assert ans.version == VersionState.NEXT
# Set to next version, which on the first call should be the start version
assert ans.version == VERSION_START


def test_createCalibrationRecord():
Expand All @@ -1192,7 +1196,8 @@ def test_createCalibrationRecord():
request.version = VersionState.NEXT
localDataService.calibrationIndexer(request.runNumber, request.useLiteMode)
ans = localDataService.createCalibrationRecord(request)
assert ans.version == VersionState.NEXT
# Set to next version, which on the first call should be the start version
assert ans.version == VERSION_START


def test_readCalibrationRecord_with_version():
Expand Down Expand Up @@ -1319,7 +1324,8 @@ def test_createNormalizationIndexEntry():
request.version = None
localDataService.normalizationIndexer(request.runNumber, request.useLiteMode)
ans = localDataService.createNormalizationIndexEntry(request)
assert ans.version == VersionState.NEXT
# Set to next version, which on the first call should be the start version
assert ans.version == VERSION_START


def test_createNormalizationRecord():
Expand All @@ -1335,7 +1341,7 @@ def test_createNormalizationRecord():

request.version = VersionState.NEXT
ans = localDataService.createNormalizationRecord(request)
assert ans.version == VersionState.NEXT
assert ans.version == VERSION_START


def test_readNormalizationRecord_with_version():
Expand Down Expand Up @@ -1958,7 +1964,7 @@ def test_readWriteNormalizationState():

ans = localDataService.readNormalizationState(runNumber, True, VersionState.LATEST)
assert ans == mockNormalizationIndexer.readParameters.return_value
mockNormalizationIndexer.readParameters.assert_called_once_with(1)
mockNormalizationIndexer.readParameters.assert_called_once_with(VersionState.LATEST)


def test_readDetectorState():
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/backend/service/test_SousChef.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def test_prepCalibration_userFWHM(self):
self.instance._getThresholdFromCalibrantSample = mock.Mock(return_value=0.5)
fakeLeft = 116
fakeRight = 17
self.ingredients.model_config["validate_assignment"] = False
self.ingredients.fwhmMultipliers = mock.Mock(left=fakeLeft, right=fakeRight)
self.instance.prepCalibrantSample = mock.Mock()

Expand Down Expand Up @@ -305,6 +306,7 @@ def test_prepPeakIngredients(self, PeakIngredients):
self.instance.prepPixelGroup = mock.Mock()
self.instance.prepCalibrantSample = mock.Mock()
calibrantSample = self.instance.prepCalibrantSample()
self.ingredients.model_config["validate_assignment"] = False
self.ingredients.peakIntensityThreshold = calibrantSample.peakIntensityFractionThreshold

result = self.instance.prepPeakIngredients(self.ingredients)
Expand Down Expand Up @@ -463,6 +465,7 @@ def test_prepReductionIngredients(self, ReductionIngredients, mockOS): # noqa:
self.instance.dataFactoryService.getCalibrationRecord = mock.Mock(return_value=record)

ingredients_ = self.ingredients.model_copy()
ingredients_.model_config["validate_assignment"] = False
# ... from calibration record:
ingredients_.calibrantSamplePath = calibrationCalibrantSamplePath
ingredients_.cifPath = self.instance.dataFactoryService.getCifFilePath.return_value
Expand Down

0 comments on commit dcd2a8f

Please sign in to comment.