Skip to content

Commit

Permalink
added the residual to the normalization workflow (#525)
Browse files Browse the repository at this point in the history
* added the residual to the normalization workflow

* add test, up coverage

* move submodule commit to correct hash
  • Loading branch information
walshmm authored Jan 15, 2025
1 parent 32de852 commit e0bbead
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pydantic import BaseModel, ConfigDict

from snapred.meta.mantid.WorkspaceNameGenerator import WorkspaceName


class CalculateNormalizationResidualRequest(BaseModel):
runNumber: int
dataWorkspace: WorkspaceName
calculationWorkspace: WorkspaceName

model_config = ConfigDict(
# required in order to use 'WorkspaceName'
arbitrary_types_allowed=True,
)
6 changes: 5 additions & 1 deletion src/snapred/backend/recipe/GenericRecipe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import Generic, TypeVar, get_args

from mantid.simpleapi import ConvertTableToMatrixWorkspace
from mantid.simpleapi import ConvertTableToMatrixWorkspace, Minus
from pydantic import BaseModel

from snapred.backend.log.logger import snapredLogger
Expand Down Expand Up @@ -109,3 +109,7 @@ class BufferMissingColumnsRecipe(GenericRecipe[BufferMissingColumnsAlgo]):

class ArtificialNormalizationRecipe(GenericRecipe[CreateArtificialNormalizationAlgo]):
pass


class MinusRecipe(GenericRecipe[Minus]):
pass
12 changes: 12 additions & 0 deletions src/snapred/backend/service/NormalizationService.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Normalization,
)
from snapred.backend.dao.request import (
CalculateNormalizationResidualRequest,
CalibrationWritePermissionsRequest,
CreateNormalizationRecordRequest,
FarmFreshIngredients,
Expand All @@ -30,6 +31,7 @@
from snapred.backend.log.logger import snapredLogger
from snapred.backend.recipe.GenericRecipe import (
FocusSpectraRecipe,
MinusRecipe,
RawVanadiumCorrectionRecipe,
SmoothDataExcludingPeaksRecipe,
)
Expand Down Expand Up @@ -407,3 +409,13 @@ def fetchMatchingNormalizations(self, request: MatchRunsRequest):
request.useLiteMode
).add()
return set(self.groceryService.fetchGroceryList(self.groceryClerk.buildList())), normalizations

@Register("calculateResidual")
def calculateResidual(self, request: CalculateNormalizationResidualRequest):
outputWorkspace = wng.normCalResidual().runNumber(request.runNumber).unit(wng.Units.DSP).build()
MinusRecipe().executeRecipe(
LHSWorkspace=request.dataWorkspace,
RHSWorkspace=request.calculationWorkspace,
OutputWorkspace=outputWorkspace,
)
return outputWorkspace
12 changes: 12 additions & 0 deletions src/snapred/meta/mantid/WorkspaceNameGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class WorkspaceType(str, Enum):
SMOOTHED_FOCUSED_RAW_VANADIUM = "smoothedFocusedRawVanadium"
ARTIFICIAL_NORMALIZATION_PREVIEW = "artificialNormalizationPreview"

RESIDUAL = "normCalResidual"

# <reduction tag>_<runNumber>_<timestamp>
REDUCTION_OUTPUT = "reductionOutput"
# <reduction tag>_<stateSHA>_<timestamp>
Expand Down Expand Up @@ -427,6 +429,16 @@ def artificialNormalizationPreview(self):
type=self.ArtificialNormWorkspaceType.PREVIEW,
)

def normCalResidual(self):
return NameBuilder(
WorkspaceType.RESIDUAL,
self._normCalResidualTemplate,
self._normCalResidualTemplateKeys,
self._delimiter,
unit=self.Units.DSP,
version=None,
)

def reductionOutput(self):
return NameBuilder(
WorkspaceType.REDUCTION_OUTPUT,
Expand Down
1 change: 1 addition & 0 deletions src/snapred/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ mantid:
focusedRawVanadium: "{unit},{group},{runNumber},raw_van_corr,{version}"
smoothedFocusedRawVanadium: "{unit},{group},{runNumber},fitted_van_corr,{version}"
artificialNormalizationPreview: "artificial_norm,{unit},{group},{runNumber},{type}"
residual: "{unit},{runNumber},residual"
reduction:
output: "reduced,{unit},{group},{runNumber},{timestamp}"
outputGroup: "reduced,{runNumber},{timestamp}"
Expand Down
6 changes: 5 additions & 1 deletion src/snapred/ui/view/NormalizationTweakPeakView.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,18 @@ def emitValueChange(self):
return
self.signalValueChanged.emit(index, smoothingValue, xtalDMin, xtalDMax)

def updateWorkspaces(self, focusWorkspace, smoothedWorkspace, peaks):
def updateWorkspaces(self, focusWorkspace, smoothedWorkspace, peaks, residualWorkspace):
self.focusWorkspace = focusWorkspace
self.smoothedWorkspace = smoothedWorkspace
self.residualWorkspace = residualWorkspace
self.groupingSchema = self.groupingFileDropdown.currentText()
self._updateGraphs(peaks)

def _updateGraphs(self, peaks):
# get the updated workspaces and optimal graph grid
focusedWorkspace = mtd[self.focusWorkspace]
smoothedWorkspace = mtd[self.smoothedWorkspace]
residualWorkspace = mtd[self.residualWorkspace]
peaks = pydantic.TypeAdapter(List[GroupPeakList]).validate_python(peaks)
numGraphs = focusedWorkspace.getNumberHistograms()
nrows, ncols = self._optimizeRowsAndCols(numGraphs)
Expand All @@ -186,6 +188,8 @@ def _updateGraphs(self, peaks):

ax.plot(focusedWorkspace, wkspIndex=i, label="Focused Data", normalize_by_bin_width=True)
ax.plot(smoothedWorkspace, wkspIndex=i, label="Smoothed Data", normalize_by_bin_width=True, linestyle="--")
ax.plot(residualWorkspace, wkspIndex=i, label="Residual Data", normalize_by_bin_width=True, linestyle=":")

ax.legend()
ax.tick_params(direction="in")
ax.set_title(f"Group ID: {i + 1}")
Expand Down
2 changes: 2 additions & 0 deletions src/snapred/ui/workflow/DiffCalWorkflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ def _specifyRun(self, workflowPresenter):
self.prevXtalDMax = payload.crystalDMax # NOTE set in __init__ to defaults
self.prevFWHM = payload.fwhmMultipliers # NOTE set in __init__ to defaults
self.prevGroupingIndex = view.groupingFileDropdown.currentIndex()

# TODO: These need to be moved to the workspace name generator
self.fitPeaksDiagnostic = f"fit_peak_diag_{self.runNumber}_{self.prevGroupingIndex}_pre"

self.residualWorkspace = f"diffcal_residual_{self.runNumber}"
Expand Down
63 changes: 41 additions & 22 deletions src/snapred/ui/workflow/NormalizationWorkflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from snapred.backend.dao.indexing.IndexEntry import IndexEntry
from snapred.backend.dao.indexing.Versioning import VersionedObject, VersionState
from snapred.backend.dao.request import (
CalculateNormalizationResidualRequest,
CalibrationWritePermissionsRequest,
CreateIndexEntryRequest,
CreateNormalizationRecordRequest,
Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(self, parent=None):
super().__init__(parent)

self.initializationComplete = False
self.normalizationResponse = None

self.samplePaths = self.request(path="config/samplePaths").data
self.defaultGroupingMap = self.request(path="config/groupingMap", payload="tmfinr").data
Expand Down Expand Up @@ -187,14 +189,24 @@ def _triggerNormalization(self, workflowPresenter):
self._saveView.updateRunNumber(self.runNumber)
self._saveView.updateBackgroundRunNumber(self.backgroundRunNumber)

response = self.request(path="normalization", payload=payload.json())
focusWorkspace = self.responses[-1].data["focusedVanadium"]
smoothWorkspace = self.responses[-1].data["smoothedVanadium"]
peaks = self.responses[-1].data["detectorPeaks"]
self.calibrationRunNumber = self.responses[-1].data["calibrationRunNumber"]
self._tweakPeakView.updateWorkspaces(focusWorkspace, smoothWorkspace, peaks)
self.normalizationResponse = self.request(path="normalization", payload=payload.json())
focusWorkspace = self.normalizationResponse.data["focusedVanadium"]
smoothWorkspace = self.normalizationResponse.data["smoothedVanadium"]
peaks = self.normalizationResponse.data["detectorPeaks"]
self.calibrationRunNumber = self.normalizationResponse.data["calibrationRunNumber"]
# calculate residual
residualWorkspace = self._calcResidual(focusWorkspace, smoothWorkspace)

self._tweakPeakView.updateWorkspaces(focusWorkspace, smoothWorkspace, peaks, residualWorkspace)
self.initializationComplete = True
return response
return self.normalizationResponse

def _calcResidual(self, focusWorkspace, smoothWorkspace):
residualReq = CalculateNormalizationResidualRequest(
runNumber=self.runNumber, dataWorkspace=focusWorkspace, calculationWorkspace=smoothWorkspace
)
residualWorkspace = self.request(path="normalization/calculateResidual", payload=residualReq).data
return residualWorkspace

@EntryExitLogger(logger=logger)
@Slot(WorkflowPresenter, result=SNAPResponse)
Expand All @@ -209,8 +221,8 @@ def _specifyNormalization(self, workflowPresenter): # noqa: ARG002
crystalDBounds={"minimum": self.prevXtalDMin, "maximum": self.prevXtalDMax},
continueFlags=self.continueAnywayFlags,
)
response = self.request(path="normalization/assessment", payload=payload.json())
return response
self.recordResponse = self.request(path="normalization/assessment", payload=payload.json())
return self.recordResponse

@EntryExitLogger(logger=logger)
@Slot(WorkflowPresenter, result=SNAPResponse)
Expand All @@ -224,10 +236,10 @@ def _saveNormalization(self, workflowPresenter):
# validate appliesTo field
appliesTo = IndexEntry.appliesToFormatChecker(appliesTo)

normalizationRecord = self.responses[-1].data
normalizationRecord.workspaceNames.append(self.responses[-2].data["smoothedVanadium"])
normalizationRecord.workspaceNames.append(self.responses[-2].data["focusedVanadium"])
normalizationRecord.workspaceNames.append(self.responses[-2].data["correctedVanadium"])
normalizationRecord = self.recordResponse.data
normalizationRecord.workspaceNames.append(self.normalizationResponse.data["smoothedVanadium"])
normalizationRecord.workspaceNames.append(self.normalizationResponse.data["focusedVanadium"])
normalizationRecord.workspaceNames.append(self.normalizationResponse.data["correctedVanadium"])

createIndexEntryRequest = CreateIndexEntryRequest(
runNumber=runNumber,
Expand Down Expand Up @@ -269,17 +281,20 @@ def callNormalization(self, index, smoothingParameter, xtalDMin, xtalDMax):
crystalDBounds={"minimum": xtalDMin, "maximum": xtalDMax},
continueFlags=self.continueAnywayFlags,
)
self.request(path="normalization", payload=payload.json())
self.normalizationResponse = self.request(path="normalization", payload=payload.json())

focusWorkspace = self.responses[-1].data["focusedVanadium"]
smoothWorkspace = self.responses[-1].data["smoothedVanadium"]
peaks = self.responses[-1].data["detectorPeaks"]
self._tweakPeakView.updateWorkspaces(focusWorkspace, smoothWorkspace, peaks)
focusWorkspace = self.normalizationResponse.data["focusedVanadium"]
smoothWorkspace = self.normalizationResponse.data["smoothedVanadium"]
peaks = self.normalizationResponse.data["detectorPeaks"]

residualWorkspace = self._calcResidual(focusWorkspace, smoothWorkspace)

self._tweakPeakView.updateWorkspaces(focusWorkspace, smoothWorkspace, peaks, residualWorkspace)

@EntryExitLogger(logger=logger)
def applySmoothingUpdate(self, index, smoothingValue, xtalDMin, xtalDMax):
focusWorkspace = self.responses[-1].data["focusedVanadium"]
smoothWorkspace = self.responses[-1].data["smoothedVanadium"]
focusWorkspace = self.normalizationResponse.data["focusedVanadium"]
smoothWorkspace = self.normalizationResponse.data["smoothedVanadium"]

payload = SmoothDataExcludingPeaksRequest(
inputWorkspace=focusWorkspace,
Expand All @@ -295,7 +310,9 @@ def applySmoothingUpdate(self, index, smoothingValue, xtalDMin, xtalDMax):
response = self.request(path="normalization/smooth", payload=payload.json())

peaks = response.data["detectorPeaks"]
self._tweakPeakView.updateWorkspaces(focusWorkspace, smoothWorkspace, peaks)
residualWorkspace = self._calcResidual(focusWorkspace, smoothWorkspace)

self._tweakPeakView.updateWorkspaces(focusWorkspace, smoothWorkspace, peaks, residualWorkspace)

@EntryExitLogger(logger=logger)
@ExceptionToErrLog
Expand Down Expand Up @@ -330,7 +347,9 @@ def renewWhenRecalculate(self, index, smoothingValue, xtalDMin, xtalDMax):
focusWorkspace = self.responses[-1].data["focusedVanadium"]
smoothWorkspace = self.responses[-1].data["smoothedVanadium"]
peaks = self.responses[-1].data["detectorPeaks"]
self._tweakPeakView.updateWorkspaces(focusWorkspace, smoothWorkspace, peaks)
residualWorkspace = self._calcResidual(focusWorkspace, smoothWorkspace)

self._tweakPeakView.updateWorkspaces(focusWorkspace, smoothWorkspace, peaks, residualWorkspace)
else:
raise Exception("Expected data not found in the last response")

Expand Down
1 change: 1 addition & 0 deletions tests/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ mantid:
focusedRawVanadium: "_{unit},{group},raw_van_corr,{runNumber},{version}"
smoothedFocusedRawVanadium: "_{unit},{group},fitted_van_corr,{runNumber},{version}"
artificialNormalizationPreview: "artificial_norm,{unit},{group},{runNumber},{type}"
residual: "{unit},{runNumber},residual"
reduction:
output: "_reduced,{unit},{group},{runNumber},{timestamp}"
outputGroup: "_reduced,{runNumber},{timestamp}"
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/backend/service/test_NormalizationService.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from pathlib import Path
from unittest.mock import ANY, MagicMock, patch

import numpy as np
import pytest
from mantid.simpleapi import (
CreateSingleValuedWorkspace,
CreateWorkspace,
mtd,
)

Expand Down Expand Up @@ -405,3 +407,21 @@ def test_differentStates(self):
self.instance.dataFactoryService.constructStateId = MagicMock()
self.instance.dataFactoryService.constructStateId.side_effect = ["state", "different_state"]
assert not self.instance._sameStates("12345", "different_state")

def test_calcResiduals(self):
self.instance = NormalizationService()
xVals = [1, 2, 3, 4, 5]
yVals = [1, 2, 3, 4]
dataWorkspace = CreateWorkspace(xVals, yVals)
calculationWorkspace = CreateWorkspace(xVals, yVals)

assert np.allclose(dataWorkspace.readY(0), np.array(yVals))

request = mock.Mock(
dataWorkspace=dataWorkspace,
calculationWorkspace=calculationWorkspace,
runNumber="12345",
)
residual = self.instance.calculateResidual(request)
assert mtd.doesExist(residual)
assert np.allclose(mtd[residual].readY(0), np.array([0, 0, 0, 0]))

0 comments on commit e0bbead

Please sign in to comment.