Skip to content

Commit

Permalink
refactor indexer to not allow possibility of misinterpreting None
Browse files Browse the repository at this point in the history
failing 40 test wip commit

all unit tests passing

changes found during manual testing of workflows

fix remaining unit tests

why does git think this file still uses VERSION_DEFAULT

add back VERSION_DEFAULT because of arbitrary ci failure?

fix versioning integration checks

migrate missed spots for datafactoryservice

correct version for integration test workspace name

fix wng?

disallow overwriting default calibration entry

disallow overwritting default calibration

add debug info for ci, because I cannot easily recreate this locally

fix integration tests, added neat summary in case they fail

move reduction completion summary note to correct spot

update tests and gitmodules

extend test coverage

catch a couple more lines

up that test coverage

change refspec?

respond to reece's comments
  • Loading branch information
walshmm committed Dec 9, 2024
1 parent 80ea2e0 commit 117b17c
Show file tree
Hide file tree
Showing 45 changed files with 837 additions and 683 deletions.
1 change: 1 addition & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[submodule "tests/data/snapred-data"]
path = tests/data/snapred-data
url = https://code.ornl.gov/sns-hfir-scse/infrastructure/test-data/snapred-data.git
branch = main
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ markers = [
"integration: mark a test as an integration test",
"mount_snap: mark a test as using /SNS/SNAP/ data mount",
"golden_data(*, path=None, short_name=None, date=None): mark golden data to use with a test",
"datarepo: mark a test as using snapred-data repo"
"datarepo: mark a test as using snapred-data repo",
"ui: mark a test as a UI test",
]
# The following will be overridden by the commandline option "-m integration"
addopts = "-m 'not (integration or datarepo)'"
Expand Down
109 changes: 40 additions & 69 deletions src/snapred/backend/dao/indexing/Versioning.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,50 @@
from typing import Any, Optional

from numpy import integer
from pydantic import BaseModel, computed_field, field_serializer

from pydantic import BaseModel, ConfigDict, field_validator
from snapred.meta.Config import Config
from snapred.meta.Enum import StrEnum

VERSION_START = Config["version.start"]
VERSION_NONE_NAME = Config["version.friendlyName.error"]
VERSION_DEFAULT_NAME = Config["version.friendlyName.default"]

# VERSION_DEFAULT is a SNAPRed-internal "magic" integer:
# * it is implicitely set during `Config` initialization.
VERSION_DEFAULT = Config["version.default"]

class VersionState(StrEnum):
DEFAULT = Config["version.friendlyName.default"]
LATEST = "latest"
NEXT = "next"


# I'm not sure why ci is failing without this, it doesn't seem to be used anywhere
VERSION_DEFAULT = VersionState.DEFAULT

Version = int | VersionState


class VersionedObject(BaseModel):
# Base class for all versioned DAO

# In pydantic, a leading double underscore activates
# the `__pydantic_private__` feature, which limits the visibility
# of the attribute to the interior scope of its own class.
__version: Optional[int] = None

@classmethod
def parseVersion(cls, version, *, exclude_none: bool = False, exclude_default: bool = False) -> int | None:
v: int | None
# handle two special cases
if (not exclude_none) and (version is None or version == VERSION_NONE_NAME):
v = None
elif (not exclude_default) and (version == VERSION_DEFAULT_NAME or version == VERSION_DEFAULT):
v = VERSION_DEFAULT
# parse integers
elif isinstance(version, int | integer):
if int(version) >= VERSION_START:
v = int(version)
else:
raise ValueError(f"Given version {version} is smaller than start version {VERSION_START}")
# otherwise this is an error
else:
raise ValueError(f"Cannot initialize version as {version}")
return v

@classmethod
def writeVersion(cls, version) -> int | str:
v: int | str
if version is None:
v = VERSION_NONE_NAME
elif version == VERSION_DEFAULT:
v = VERSION_DEFAULT_NAME
elif isinstance(version, int | integer):
v = int(version)
else:
raise ValueError("Version is not valid")
return v

def __init__(self, **kwargs):
version = kwargs.pop("version", None)
super().__init__(**kwargs)
self.__version = self.parseVersion(version)

@field_serializer("version", check_fields=False, when_used="json")
def write_user_defaults(self, value: Any): # noqa ARG002
return self.writeVersion(self.__version)

# NOTE some serialization still using the dict() method
def dict(self, **kwargs):
res = super().dict(**kwargs)
res["version"] = self.writeVersion(res["version"])
return res

@computed_field
@property
def version(self) -> int:
return self.__version

@version.setter
def version(self, v):
self.__version = self.parseVersion(v, exclude_none=True)
version: Version

@field_validator("version", mode="before")
def validate_version(cls, value: Version) -> Version:
if value in VersionState.values():
return value

if isinstance(value, str):
raise ValueError(f"Version must be an int or {VersionState.values()}")

if value is None:
raise ValueError("Version must be specified")

if value < VERSION_START:
raise ValueError(f"Version must be greater than {VERSION_START}")

return value

# NOTE: This approach was taken because 'field_serializer' was checking against the
# INITIAL value of version for some reason. This is a workaround.
#
def model_dump_json(self, *args, **kwargs): # noqa ARG002
if self.version in VersionState.values():
raise ValueError(f"Version {self.version} must be flattened to an int before writing to JSON")
return super().model_dump_json(*args, **kwargs)

model_config = ConfigDict(use_enum_values=True, validate_assignment=True)
12 changes: 4 additions & 8 deletions src/snapred/backend/dao/normalization/NormalizationRecord.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Any, List

from pydantic import field_serializer, field_validator
from pydantic import field_validator

from snapred.backend.dao.indexing.Record import Record
from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT, VersionedObject
from snapred.backend.dao.indexing.Versioning import VERSION_START, Version, VersionedObject
from snapred.backend.dao.Limit import Limit
from snapred.backend.dao.normalization.Normalization import Normalization
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName
Expand Down Expand Up @@ -31,7 +31,7 @@ class NormalizationRecord(Record, extra="ignore"):
smoothingParameter: float
# detectorPeaks: List[DetectorPeak] # TODO: need to save this for reference during reduction
workspaceNames: List[WorkspaceName] = []
calibrationVersionUsed: int = VERSION_DEFAULT
calibrationVersionUsed: Version = VERSION_START
crystalDBounds: Limit[float]
normalizationCalibrantSamplePath: str

Expand All @@ -44,8 +44,4 @@ def validate_backgroundRunNumber(cls, v: Any) -> Any:
@field_validator("calibrationVersionUsed", mode="before")
@classmethod
def version_is_integer(cls, v: Any) -> Any:
return VersionedObject.parseVersion(v)

@field_serializer("calibrationVersionUsed", when_used="json")
def write_user_defaults(self, value: Any): # noqa ARG002
return VersionedObject.writeVersion(self.calibrationVersionUsed)
return VersionedObject(version=v).version
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from snapred.backend.dao.calibration.Calibration import Calibration
from snapred.backend.dao.calibration.FocusGroupMetric import FocusGroupMetric
from snapred.backend.dao.CrystallographicInfo import CrystallographicInfo
from snapred.backend.dao.indexing.Versioning import Version, VersionState
from snapred.backend.dao.state.PixelGroup import PixelGroup
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName, WorkspaceType

Expand All @@ -18,7 +19,7 @@ class CreateCalibrationRecordRequest(BaseModel, extra="forbid"):

runNumber: str
useLiteMode: bool
version: Optional[int] = None
version: Version = VersionState.NEXT
calculationParameters: Calibration
crystalInfo: CrystallographicInfo
pixelGroups: Optional[List[PixelGroup]] = None
Expand Down
6 changes: 3 additions & 3 deletions src/snapred/backend/dao/request/CreateIndexEntryRequest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from pydantic import BaseModel

from snapred.backend.dao.indexing.Versioning import Version, VersionState


class CreateIndexEntryRequest(BaseModel):
"""
Expand All @@ -10,7 +10,7 @@ class CreateIndexEntryRequest(BaseModel):

runNumber: str
useLiteMode: bool
version: Optional[int] = None
version: Version = VersionState.NEXT
comments: str
author: str
appliesTo: str
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel, field_validator

from snapred.backend.dao.indexing.Versioning import VERSION_DEFAULT
from snapred.backend.dao.indexing.Versioning import Version, VersionState
from snapred.backend.dao.Limit import Pair
from snapred.backend.dao.state.FocusGroup import FocusGroup
from snapred.backend.error.ContinueWarning import ContinueWarning
Expand Down Expand Up @@ -40,7 +40,7 @@ class DiffractionCalibrationRequest(BaseModel, extra="forbid"):

continueFlags: Optional[ContinueWarning.Type] = ContinueWarning.Type.UNSET

startingTableVersion: int = VERSION_DEFAULT
startingTableVersion: Version = VersionState.DEFAULT

@field_validator("fwhmMultipliers", mode="before")
@classmethod
Expand Down
13 changes: 9 additions & 4 deletions src/snapred/backend/dao/request/FarmFreshIngredients.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from pydantic import BaseModel, ConfigDict, ValidationError, field_validator, model_validator

from snapred.backend.dao.indexing.Versioning import Version, VersionState
from snapred.backend.dao.Limit import Limit, Pair
from snapred.backend.dao.state import FocusGroup
from snapred.meta.Config import Config
from snapred.meta.mantid.AllowedPeakTypes import SymmetricPeakEnum

# TODO: this declaration is duplicated in `ReductionRequest`.
Versions = NamedTuple("Versions", [("calibration", Optional[int]), ("normalization", Optional[int])])
Versions = NamedTuple("Versions", [("calibration", Version), ("normalization", Version)])


class FarmFreshIngredients(BaseModel):
Expand All @@ -21,7 +22,7 @@ class FarmFreshIngredients(BaseModel):

runNumber: str

versions: Versions = Versions(None, None)
versions: Versions = Versions(VersionState.LATEST, VersionState.LATEST)

# allow 'versions' to be accessed as a single version,
# or, to be accessed ambiguously
Expand All @@ -33,7 +34,7 @@ def version(self) -> Optional[int]:

@version.setter
def version(self, v: Optional[int]):
self.versions = (v, None)
self.versions = Versions(v, None)

useLiteMode: bool

Expand Down Expand Up @@ -83,6 +84,10 @@ def focusGroup(self, fg: FocusGroup):
def validate_versions(cls, v) -> Versions:
if not isinstance(v, Versions):
v = Versions(v)
if v.calibration is None:
raise ValueError("Calibration version must be specified")
if v.normalization is None:
raise ValueError("Normalization version must be specified")
return v

@field_validator("crystalDBounds", mode="before")
Expand Down Expand Up @@ -119,4 +124,4 @@ def validate_focusGroups(cls, v: Any):
del v["focusGroup"]
return v

model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="forbid", validate_assignment=True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pydantic import BaseModel

from snapred.backend.dao import RunConfig
from snapred.backend.dao.indexing.Versioning import Version


class LoadCalibrationRecordRequest(BaseModel):
runConfig: RunConfig
version: Version
9 changes: 7 additions & 2 deletions src/snapred/backend/dao/request/ReductionRequest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from pydantic import BaseModel, ConfigDict, field_validator

from snapred.backend.dao.indexing.Versioning import Version, VersionState
from snapred.backend.dao.ingredients import ArtificialNormalizationIngredients
from snapred.backend.dao.state.FocusGroup import FocusGroup
from snapred.backend.error.ContinueWarning import ContinueWarning
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName
from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceNameGenerator as wng

Versions = NamedTuple("Versions", [("calibration", Optional[int]), ("normalization", Optional[int])])
Versions = NamedTuple("Versions", [("calibration", Version), ("normalization", Version)])


class ReductionRequest(BaseModel):
Expand All @@ -22,7 +23,7 @@ class ReductionRequest(BaseModel):

# Calibration and normalization versions:
# `None` => <use latest version>
versions: Versions = Versions(None, None)
versions: Versions = Versions(VersionState.LATEST, VersionState.LATEST)

pixelMasks: List[WorkspaceName] = []
artificialNormalizationIngredients: Optional[ArtificialNormalizationIngredients] = None
Expand All @@ -37,6 +38,10 @@ def validate_versions(cls, v) -> Versions:
if not isinstance(v, Tuple):
raise ValueError("'versions' must be a tuple: '(<calibration version>, <normalization version>)'")
v = Versions(v)
if v.calibration is None:
raise ValueError("Calibration version must be specified")
if v.normalization is None:
raise ValueError("Normalization version must be specified")
return v

model_config = ConfigDict(
Expand Down
10 changes: 5 additions & 5 deletions src/snapred/backend/data/DataExportService.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
from pathlib import Path
from typing import Tuple
from typing import Optional, Tuple

from pydantic import validate_call

Expand Down Expand Up @@ -64,11 +64,11 @@ def exportCalibrationIndexEntry(self, entry: IndexEntry):
"""
self.dataService.writeCalibrationIndexEntry(entry)

def exportCalibrationRecord(self, record: CalibrationRecord):
def exportCalibrationRecord(self, record: CalibrationRecord, entry: Optional[IndexEntry] = None):
"""
Record must have correct version set.
"""
self.dataService.writeCalibrationRecord(record)
self.dataService.writeCalibrationRecord(record, entry)

def exportCalibrationWorkspaces(self, record: CalibrationRecord):
"""
Expand All @@ -94,11 +94,11 @@ def exportNormalizationIndexEntry(self, entry: IndexEntry):
"""
self.dataService.writeNormalizationIndexEntry(entry)

def exportNormalizationRecord(self, record: NormalizationRecord):
def exportNormalizationRecord(self, record: NormalizationRecord, entry: Optional[IndexEntry] = None):
"""
Record must have correct version set.
"""
self.dataService.writeNormalizationRecord(record)
self.dataService.writeNormalizationRecord(record, entry)

def exportNormalizationWorkspaces(self, record: NormalizationRecord):
"""
Expand Down
Loading

0 comments on commit 117b17c

Please sign in to comment.