Skip to content

Commit

Permalink
changes found during manual testing of workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
walshmm committed Nov 25, 2024
1 parent 8383e0d commit 7993b55
Show file tree
Hide file tree
Showing 16 changed files with 103 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from snapred.backend.dao.calibration.Calibration import Calibration
from snapred.backend.dao.calibration.FocusGroupMetric import FocusGroupMetric
from snapred.backend.dao.CrystallographicInfo import CrystallographicInfo
from snapred.backend.dao.indexing.Versioning import Version, VersionState
from snapred.backend.dao.state.PixelGroup import PixelGroup
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName, WorkspaceType

Expand All @@ -18,7 +19,7 @@ class CreateCalibrationRecordRequest(BaseModel, extra="forbid"):

runNumber: str
useLiteMode: bool
version: Optional[int] = None
version: Version = VersionState.NEXT
calculationParameters: Calibration
crystalInfo: CrystallographicInfo
pixelGroups: Optional[List[PixelGroup]] = None
Expand Down
6 changes: 3 additions & 3 deletions src/snapred/backend/dao/request/CreateIndexEntryRequest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from pydantic import BaseModel

from snapred.backend.dao.indexing.Versioning import Version, VersionState


class CreateIndexEntryRequest(BaseModel):
"""
Expand All @@ -10,7 +10,7 @@ class CreateIndexEntryRequest(BaseModel):

runNumber: str
useLiteMode: bool
version: Optional[int] = None
version: Version = VersionState.NEXT
comments: str
author: str
appliesTo: str
11 changes: 8 additions & 3 deletions src/snapred/backend/dao/request/FarmFreshIngredients.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from pydantic import BaseModel, ConfigDict, ValidationError, field_validator, model_validator

from snapred.backend.dao.indexing.Versioning import Version, VersionState
from snapred.backend.dao.Limit import Limit, Pair
from snapred.backend.dao.state import FocusGroup
from snapred.meta.Config import Config
from snapred.meta.mantid.AllowedPeakTypes import SymmetricPeakEnum

# TODO: this declaration is duplicated in `ReductionRequest`.
Versions = NamedTuple("Versions", [("calibration", Optional[int]), ("normalization", Optional[int])])
Versions = NamedTuple("Versions", [("calibration", Version), ("normalization", Version)])


class FarmFreshIngredients(BaseModel):
Expand All @@ -21,7 +22,7 @@ class FarmFreshIngredients(BaseModel):

runNumber: str

versions: Versions = Versions(None, None)
versions: Versions = Versions(VersionState.LATEST, VersionState.LATEST)

# allow 'versions' to be accessed as a single version,
# or, to be accessed ambiguously
Expand Down Expand Up @@ -83,6 +84,10 @@ def focusGroup(self, fg: FocusGroup):
def validate_versions(cls, v) -> Versions:
if not isinstance(v, Versions):
v = Versions(v)
if v.calibration is None:
raise ValueError("Calibration version must be specified")
if v.normalization is None:
raise ValueError("Normalization version must be specified")
return v

@field_validator("crystalDBounds", mode="before")
Expand Down Expand Up @@ -119,4 +124,4 @@ def validate_focusGroups(cls, v: Any):
del v["focusGroup"]
return v

model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="forbid", validate_assignment=True)
9 changes: 7 additions & 2 deletions src/snapred/backend/dao/request/ReductionRequest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from pydantic import BaseModel, ConfigDict, field_validator

from snapred.backend.dao.indexing.Versioning import Version, VersionState
from snapred.backend.dao.ingredients import ArtificialNormalizationIngredients
from snapred.backend.dao.state.FocusGroup import FocusGroup
from snapred.backend.error.ContinueWarning import ContinueWarning
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceNameGenerator as wng

Versions = NamedTuple("Versions", [("calibration", Optional[int]), ("normalization", Optional[int])])
Versions = NamedTuple("Versions", [("calibration", Version), ("normalization", Version)])


class ReductionRequest(BaseModel):
Expand All @@ -22,7 +23,7 @@ class ReductionRequest(BaseModel):

# Calibration and normalization versions:
# `None` => <use latest version>
versions: Versions = Versions(None, None)
versions: Versions = Versions(VersionState.LATEST, VersionState.LATEST)

pixelMasks: List[WorkspaceName] = []
artificialNormalizationIngredients: Optional[ArtificialNormalizationIngredients] = None
Expand All @@ -37,6 +38,10 @@ def validate_versions(cls, v) -> Versions:
if not isinstance(v, Tuple):
raise ValueError("'versions' must be a tuple: '(<calibration version>, <normalization version>)'")
v = Versions(v)
if v.calibration is None:
raise ValueError("Calibration version must be specified")
if v.normalization is None:
raise ValueError("Normalization version must be specified")
return v

model_config = ConfigDict(
Expand Down
10 changes: 5 additions & 5 deletions src/snapred/backend/data/DataExportService.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
from pathlib import Path
from typing import Tuple
from typing import Optional, Tuple

from pydantic import validate_call

Expand Down Expand Up @@ -64,11 +64,11 @@ def exportCalibrationIndexEntry(self, entry: IndexEntry):
"""
self.dataService.writeCalibrationIndexEntry(entry)

def exportCalibrationRecord(self, record: CalibrationRecord):
def exportCalibrationRecord(self, record: CalibrationRecord, entry: Optional[IndexEntry] = None):
"""
Record must have correct version set.
"""
self.dataService.writeCalibrationRecord(record)
self.dataService.writeCalibrationRecord(record, entry)

def exportCalibrationWorkspaces(self, record: CalibrationRecord):
"""
Expand All @@ -94,11 +94,11 @@ def exportNormalizationIndexEntry(self, entry: IndexEntry):
"""
self.dataService.writeNormalizationIndexEntry(entry)

def exportNormalizationRecord(self, record: NormalizationRecord):
def exportNormalizationRecord(self, record: NormalizationRecord, entry: Optional[IndexEntry] = None):
"""
Record must have correct version set.
"""
self.dataService.writeNormalizationRecord(record)
self.dataService.writeNormalizationRecord(record, entry)

def exportNormalizationWorkspaces(self, record: NormalizationRecord):
"""
Expand Down
29 changes: 19 additions & 10 deletions src/snapred/backend/data/DataFactoryService.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List

from pydantic import validate_call

from snapred.backend.dao.calibration.CalibrationRecord import CalibrationRecord
from snapred.backend.dao.indexing.IndexEntry import IndexEntry
from snapred.backend.dao.indexing.Versioning import Version, VersionState
from snapred.backend.dao.InstrumentConfig import InstrumentConfig
from snapred.backend.dao.normalization.NormalizationRecord import NormalizationRecord
from snapred.backend.dao.reduction import ReductionRecord
Expand Down Expand Up @@ -81,7 +82,7 @@ def calibrationExists(self, runId: str, useLiteMode: bool):
return self.lookupService.calibrationExists(runId, useLiteMode)

@validate_call
def getCalibrationDataPath(self, runId: str, useLiteMode: bool, version: int):
def getCalibrationDataPath(self, runId: str, useLiteMode: bool, version: Version):
return self.lookupService.calibrationIndexer(runId, useLiteMode).versionPath(version)

def checkCalibrationStateExists(self, runId: str):
Expand All @@ -102,14 +103,18 @@ def getCalibrationIndex(self, runId: str, useLiteMode: bool) -> List[IndexEntry]
return self.lookupService.calibrationIndexer(runId, useLiteMode).getIndex()

@validate_call
def getCalibrationRecord(self, runId: str, useLiteMode: bool, version: Optional[int] = None) -> CalibrationRecord:
def getCalibrationRecord(
self, runId: str, useLiteMode: bool, version: Version = VersionState.LATEST
) -> CalibrationRecord:
"""
If no version is passed, will use the latest version applicable to runId
"""
if version is None:
raise ValueError("Version must be specified")
return self.lookupService.readCalibrationRecord(runId, useLiteMode, version)

@validate_call
def getCalibrationDataWorkspace(self, runId: str, useLiteMode: bool, version: int, name: str):
def getCalibrationDataWorkspace(self, runId: str, useLiteMode: bool, version: Version, name: str):
path = self.lookupService.calibrationIndexer(runId, useLiteMode).versionPath(version)
return self.groceryService.fetchWorkspace(os.path.join(path, name) + ".nxs", name)

Expand All @@ -123,7 +128,7 @@ def normalizationExists(self, runId: str, useLiteMode: bool):
return self.lookupService.normalizationExists(runId, useLiteMode)

@validate_call
def getNormalizationDataPath(self, runId: str, useLiteMode: bool, version: int):
def getNormalizationDataPath(self, runId: str, useLiteMode: bool, version: Version):
return self.lookupService.normalizationIndexer(runId, useLiteMode).versionPath(version)

def createNormalizationIndexEntry(self, request: NormalizationExportRequest) -> IndexEntry:
Expand All @@ -141,14 +146,16 @@ def getNormalizationIndex(self, runId: str, useLiteMode: bool) -> List[IndexEntr
return self.lookupService.normalizationIndexer(runId, useLiteMode).getIndex()

@validate_call
def getNormalizationRecord(self, runId: str, useLiteMode: bool, version: Optional[int] = None):
def getNormalizationRecord(
self, runId: str, useLiteMode: bool, version: Version = VersionState.LATEST
) -> NormalizationRecord:
"""
If no version is passed, will use the latest version applicable to runId
"""
return self.lookupService.readNormalizationRecord(runId, useLiteMode, version)

@validate_call
def getNormalizationDataWorkspace(self, runId: str, useLiteMode: bool, version: int, name: str):
def getNormalizationDataWorkspace(self, runId: str, useLiteMode: bool, version: Version, name: str):
path = self.getNormalizationDataPath(runId, useLiteMode, version)
return self.groceryService.fetchWorkspace(os.path.join(path, name) + ".nxs", name)

Expand Down Expand Up @@ -179,18 +186,20 @@ def getReductionState(self, runId: str, useLiteMode: bool) -> ReductionState:
return reductionState

@validate_call
def getReductionDataPath(self, runId: str, useLiteMode: bool, version: int) -> Path:
def getReductionDataPath(self, runId: str, useLiteMode: bool, version: Version) -> Path:
return self.lookupService._constructReductionDataPath(runId, useLiteMode, version)

@validate_call
def getReductionRecord(self, runId: str, useLiteMode: bool, version: Optional[int] = None) -> ReductionRecord:
def getReductionRecord(
self, runId: str, useLiteMode: bool, version: Version = VersionState.LATEST
) -> ReductionRecord:
"""
If no version is passed, will use the latest version applicable to runId
"""
return self.lookupService.readReductionRecord(runId, useLiteMode, version)

@validate_call
def getReductionData(self, runId: str, useLiteMode: bool, version: int) -> ReductionRecord:
def getReductionData(self, runId: str, useLiteMode: bool, version: Version) -> ReductionRecord:
return self.lookupService.readReductionData(runId, useLiteMode, version)

@validate_call
Expand Down
4 changes: 2 additions & 2 deletions src/snapred/backend/data/GroceryService.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from pydantic import validate_call

from snapred.backend.dao.indexing.Versioning import VersionState
from snapred.backend.dao.indexing.Versioning import VERSION_START, VersionState
from snapred.backend.dao.ingredients import GroceryListItem
from snapred.backend.dao.state import DetectorState
from snapred.backend.dao.WorkspaceMetadata import WorkspaceMetadata
Expand Down Expand Up @@ -338,7 +338,7 @@ def createDiffcalTableWorkspaceName(
NOTE: This method will IGNORE runNumber if the provided version is VersionState.DEFAULT
"""
wsName = wng.diffCalTable().runNumber(runNumber).version(version).build()
if version == VersionState.DEFAULT:
if version in [VersionState.DEFAULT, VERSION_START]:
wsName = wsName = wng.diffCalTable().runNumber("default").version(VersionState.DEFAULT).build()
return wsName

Expand Down
23 changes: 16 additions & 7 deletions src/snapred/backend/data/Indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def readDirectoryList(self):
version = str(fname).split("_")[-1]
# Warning: order matters here:
# check VersionState.DEFAULT _before_ the `isdigit` check.
if version == VersionState.DEFAULT:
if version in [VersionState.DEFAULT, self.defaultVersion()]:
version = self.defaultVersion()
elif version.isdigit():
version = int(version)
Expand Down Expand Up @@ -173,7 +173,7 @@ def latestApplicableVersion(self, runNumber: str) -> int:
elif len(relevantEntries) == 1:
version = relevantEntries[0].version
else:
if VersionState.DEFAULT in self.index:
if self.defaultVersion() in self.index:
relevantEntries.remove(self.index[self.defaultVersion()])
version = relevantEntries[-1].version
return version
Expand Down Expand Up @@ -272,7 +272,7 @@ def getLatestApplicablePath(self, runNumber: str) -> Path:

def createIndexEntry(self, *, version, **other_arguments):
return IndexEntry(
version=version,
version=self._flattenVersion(version),
**other_arguments,
)

Expand Down Expand Up @@ -311,7 +311,7 @@ def addIndexEntry(self, entry: IndexEntry):

def createRecord(self, *, version, **other_arguments):
record = RECORD_TYPE[self.indexerType](
version=version,
version=self._flattenVersion(version),
**other_arguments,
)
record.calculationParameters.version = record.version
Expand Down Expand Up @@ -349,15 +349,24 @@ def _flattenVersion(self, version: Version):
elif isinstance(version, int):
return version
else:
if version is None:
raise ValueError("Version must be specified," "likely no available versions found during lookup.")
raise ValueError(f"Version must be an int or {VersionState.values()}")

def versionExists(self, version: Version):
return self._flattenVersion(version) in self.index

def writeNewRecord(self, record: Record, entry: IndexEntry):
"""
Coupled write of a record and an index entry.
As required for new records.
"""
if record.version in self.index:
if self.versionExists(record.version):
raise ValueError(f"Version {record.version} already exists in index, please write a new version.")

if entry.appliesTo is None:
entry.appliesTo = ">=" + record.runNumber

self.addIndexEntry(entry)
# make sure they flatten to the same value.
record.version = entry.version
Expand All @@ -370,7 +379,7 @@ def writeRecord(self, record: Record):
"""
record.version = self._flattenVersion(record.version)

if record.version not in self.index:
if not self.versionExists(record.version):
raise ValueError(f"Version {record.version} not found in index, please write an index entry first.")

filePath = self.recordPath(record.version)
Expand All @@ -384,7 +393,7 @@ def writeRecord(self, record: Record):

def createParameters(self, *, version, **other_arguments) -> CalculationParameters:
return PARAMS_TYPE[self.indexerType](
version=version,
version=self._flattenVersion(version),
**other_arguments,
)

Expand Down
Loading

0 comments on commit 7993b55

Please sign in to comment.