Skip to content

Commit

Permalink
fix failing unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
walshmm committed Dec 9, 2024
1 parent 611c6db commit a0b2a81
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 12 deletions.
4 changes: 3 additions & 1 deletion src/snapred/backend/data/LocalDataService.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,8 +850,10 @@ def readCalibrationState(self, runId: str, useLiteMode: bool, version: Optional[
return parameters

@validate_call
def readNormalizationState(self, runId: str, useLiteMode: bool, version: Optional[Version] = VersionState.LATEST):
def readNormalizationState(self, runId: str, useLiteMode: bool, version: Version = VersionState.LATEST):
indexer = self.normalizationIndexer(runId, useLiteMode)
if version is VersionState.LATEST:
version = indexer.latestApplicableVersion(runId)
return indexer.readParameters(version)

def writeCalibrationState(self, calibration: Calibration):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/backend/data/test_DataFactoryService.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ def setUpClass(cls):
cls.mockLookupService.calibrationIndexer.return_value = mock.Mock(
versionPath=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Calibration", *x)),
getIndex=mock.Mock(return_value=[cls.expected(cls, "Calibration")]),
getLatestApplicableVersion=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Calibration", *x)),
latestApplicableVersion=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Calibration", *x)),
)
cls.mockLookupService.normalizationIndexer.return_value = mock.Mock(
versionPath=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Normalization", *x)),
getIndex=mock.Mock(return_value=[cls.expected(cls, "Normalization")]),
getLatestApplicableVersion=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Normalization", *x)),
latestApplicableVersion=mock.Mock(side_effect=lambda *x: cls.expected(cls, "Normalization", *x)),
)

def setUp(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/backend/data/test_Indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from unittest import mock

import pytest
from util.dao import DAOFactory

from snapred.backend.dao.calibration.Calibration import Calibration
from snapred.backend.dao.calibration.CalibrationRecord import CalibrationRecord
Expand Down Expand Up @@ -275,7 +276,6 @@ def test_flattenVersion(self):
indexer.currentVersion = lambda: 3
indexer.nextVersion = lambda: 4
assert indexer._flattenVersion(VersionState.DEFAULT) == indexer.defaultVersion()
assert indexer._flattenVersion(VersionState.LATEST) == indexer.currentVersion()
assert indexer._flattenVersion(VersionState.NEXT) == indexer.nextVersion()
assert indexer._flattenVersion(3) == 3

Expand Down
28 changes: 20 additions & 8 deletions tests/unit/backend/data/test_LocalDataService.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,11 @@ def test_readStateConfig_default():
indexer = localDataService.calibrationIndexer("57514", True)
parameters = DAOFactory.calibrationParameters("57514", True, indexer.defaultVersion())
tmpRoot.saveObjectAt(groupingMap, localDataService._groupingMapPath(tmpRoot.stateId))
tmpRoot.saveObjectAt(parameters, indexer.parametersPath(VersionState.DEFAULT))
indexer.index = {VersionState.DEFAULT: mock.Mock()} # NOTE manually update the Indexer
tmpRoot.saveObjectAt(parameters, indexer.parametersPath(indexer.defaultVersion()))

indexer.index = {
VersionState.DEFAULT: mock.MagicMock(appliesTo="57514", version=indexer.defaultVersion())
} # NOTE manually update the Indexer
actual = localDataService.readStateConfig("57514", True)
assert actual is not None
assert actual.stateId == DAOFactory.magical_state_id
Expand All @@ -332,7 +335,9 @@ def test_readStateConfig_previous():
indexer = localDataService.calibrationIndexer("57514", True)
tmpRoot.saveObjectAt(groupingMap, localDataService._groupingMapPath(tmpRoot.stateId))
tmpRoot.saveObjectAt(parameters, indexer.parametersPath(version))
indexer.index = {version: mock.Mock()} # NOTE manually update the Indexer
indexer.index = {
version: mock.MagicMock(appliesTo="57514", version=version)
} # NOTE manually update the Indexer
actual = localDataService.readStateConfig("57514", True)
assert actual is not None
assert actual.stateId == DAOFactory.magical_state_id
Expand All @@ -348,7 +353,9 @@ def test_readStateConfig_attaches_grouping_map():
indexer = localDataService.calibrationIndexer("57514", True)
tmpRoot.saveObjectAt(groupingMap, localDataService._groupingMapPath(tmpRoot.stateId))
tmpRoot.saveObjectAt(parameters, indexer.parametersPath(version))
indexer.index = {version: mock.Mock()} # NOTE manually update the Indexer
indexer.index = {
version: mock.MagicMock(appliesTo="57514", version=version)
} # NOTE manually update the Indexer
actual = localDataService.readStateConfig("57514", True)
expectedMap = DAOFactory.groupingMap_SNAP()
assert actual.groupingMap == expectedMap
Expand All @@ -365,7 +372,9 @@ def test_readStateConfig_invalid_grouping_map():
indexer = localDataService.calibrationIndexer("57514", True)
tmpRoot.saveObjectAt(groupingMap, localDataService._groupingMapPath(tmpRoot.stateId))
tmpRoot.saveObjectAt(parameters, indexer.parametersPath(version))
indexer.index = {version: mock.Mock()} # NOTE manually update the Indexer
indexer.index = {
version: mock.MagicMock(appliesTo="57514", version=version)
} # NOTE manually update the Indexer
# 'GroupingMap.defaultStateId' is _not_ a valid grouping-map 'stateId' for an existing `StateConfig`.
with pytest.raises( # noqa: PT012
RuntimeError,
Expand All @@ -383,7 +392,9 @@ def test_readStateConfig_calls_prepareStateRoot():
with state_root_redirect(localDataService, stateId=expected.instrumentState.id.hex) as tmpRoot:
indexer = localDataService.calibrationIndexer("57514", True)
tmpRoot.saveObjectAt(expected, indexer.parametersPath(version))
indexer.index = {version: mock.Mock()} # NOTE manually update the Indexer
indexer.index = {
version: mock.MagicMock(appliesTo="57514", version=version)
} # NOTE manually update the Indexer
assert not localDataService._groupingMapPath(tmpRoot.stateId).exists()
localDataService._prepareStateRoot = mock.Mock(
side_effect=lambda x: tmpRoot.saveObjectAt( # noqa ARG005
Expand Down Expand Up @@ -720,7 +731,8 @@ def test_write_model_pretty_StateConfig_excludes_grouping_map():
# move the calculation parameters into correct folder
indexer = localDataService.calibrationIndexer("57514", True)
indexer.writeParameters(DAOFactory.calibrationParameters("57514", True, indexer.defaultVersion()))
indexer.index = {indexer.defaultVersion(): mock.Mock()}
indexer.index = {indexer.defaultVersion(): mock.MagicMock(appliesTo="57514", version=indexer.defaultVersion())}

# move the grouping map into correct folder
write_model_pretty(DAOFactory.groupingMap_SNAP(), localDataService._groupingMapPath(tmpRoot.stateId))

Expand Down Expand Up @@ -2120,7 +2132,7 @@ def test_readWriteNormalizationState():

ans = localDataService.readNormalizationState(runNumber, True, VersionState.LATEST)
assert ans == mockNormalizationIndexer.readParameters.return_value
mockNormalizationIndexer.readParameters.assert_called_once_with(VersionState.LATEST)
mockNormalizationIndexer.readParameters.assert_called_once_with(1)


def test_readDetectorState():
Expand Down

0 comments on commit a0b2a81

Please sign in to comment.