Skip to content

Commit f490bdd

Browse files
Merge pull request #145 from neutrons/magnetic_form_factor_correction
Add magnetic structure factor correction
2 parents 71fc5f0 + 6cdaa3a commit f490bdd

File tree

8 files changed

+217
-19
lines changed

8 files changed

+217
-19
lines changed

Diff for: .github/workflows/actions.yml

+7-5
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ jobs:
1414
run:
1515
shell: bash -l {0}
1616
steps:
17-
- uses: actions/checkout@v3
18-
- uses: mamba-org/setup-micromamba@v1
17+
- uses: actions/checkout@v4
18+
- uses: mamba-org/setup-micromamba@v2
1919
with:
2020
environment-file: environment.yml
2121
- name: pylint
@@ -25,16 +25,18 @@ jobs:
2525
- name: Run tests
2626
run: xvfb-run --server-args="-screen 0 1920x1080x24" -a python -m pytest --cov=src --cov-report=xml --cov-report=term
2727
- name: Upload coverage reports to Codecov
28-
uses: codecov/codecov-action@v3
28+
uses: codecov/codecov-action@v5
29+
with:
30+
token: ${{ secrets.CODECOV_TOKEN }}
2931

3032
conda-build:
3133
runs-on: ubuntu-latest
3234
defaults:
3335
run:
3436
shell: bash -l {0}
3537
steps:
36-
- uses: actions/checkout@v3
37-
- uses: mamba-org/setup-micromamba@v1
38+
- uses: actions/checkout@v4
39+
- uses: mamba-org/setup-micromamba@v2
3840
with:
3941
environment-file: environment.yml
4042
condarc: |

Diff for: environment.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name: shiver
22
channels:
33
- conda-forge
4-
- mantid
4+
- mantid/label/nightly
55
- oncat
66
- neutrons
77
dependencies:
@@ -11,7 +11,7 @@ dependencies:
1111
- conda-verify
1212
- check-wheel-contents
1313
- wheel
14-
- mantidworkbench
14+
- mantidworkbench>6.11.20241216.1838
1515
- pre-commit
1616
- versioningit
1717
- pylint=2.17.3

Diff for: src/shiver/models/corrections.py

+157-1
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@ class CorrectionsModel:
1515
def __init__(self) -> None:
1616
self.algorithms_observers = set() # need to add them here so they stay in scope
1717
self.error_callback = None
18+
self.algorithm_running = False
1819

1920
def apply(
2021
self,
2122
ws_name: str,
2223
detailed_balance: bool,
2324
hyspec_polarizer_transmission: bool,
2425
temperature: str = "",
26+
magentic_structure_factor: bool = False,
27+
ion_name: str = "",
2528
) -> None:
2629
"""Apply corrections.
2730
@@ -52,6 +55,8 @@ def apply(
5255
output_ws_name = f"{ws_name}_DB"
5356
if hyspec_polarizer_transmission:
5457
output_ws_name = f"{output_ws_name}_PT"
58+
if magentic_structure_factor:
59+
output_ws_name = f"{output_ws_name}_MSF"
5560

5661
if detailed_balance:
5762
self.apply_detailed_balance(
@@ -64,14 +69,27 @@ def apply(
6469
input_ws_name = ws_name
6570
if detailed_balance:
6671
# wait for detailed balance to finish
67-
while not mtd.doesExist(output_ws_name):
72+
while self.algorithm_running:
6873
time.sleep(0.1)
6974
input_ws_name = output_ws_name
7075
self.apply_scattered_transmission_correction(
7176
input_ws_name,
7277
output_ws_name,
7378
)
7479

80+
if magentic_structure_factor:
81+
input_ws_name = ws_name
82+
if hyspec_polarizer_transmission or detailed_balance:
83+
# wait for others to finish
84+
while self.algorithm_running:
85+
time.sleep(0.1)
86+
input_ws_name = output_ws_name
87+
self.apply_magnetic_form_factor_correction(
88+
input_ws_name,
89+
ion_name,
90+
output_ws_name,
91+
)
92+
7593
def apply_detailed_balance(
7694
self,
7795
ws_name: str,
@@ -99,6 +117,8 @@ def apply_detailed_balance(
99117
)
100118
self.algorithms_observers.add(alg_obs)
101119

120+
self.algorithm_running = True
121+
102122
alg = AlgorithmManager.create("ApplyDetailedBalanceMD")
103123

104124
alg_obs.observeFinish(alg)
@@ -115,6 +135,7 @@ def apply_detailed_balance(
115135
logger.error(str(err))
116136
if self.error_callback:
117137
self.error_callback(str(err))
138+
self.algorithm_running = False
118139

119140
def apply_scattered_transmission_correction(
120141
self,
@@ -143,6 +164,8 @@ def apply_scattered_transmission_correction(
143164
)
144165
self.algorithms_observers.add(alg_obs)
145166

167+
self.algorithm_running = True
168+
146169
exponent_factor = 1.0 / 11.0 # see ref above
147170

148171
logger.information(f"Applying DGS Scattered Transmission Correction with exponent factor {exponent_factor}")
@@ -162,6 +185,56 @@ def apply_scattered_transmission_correction(
162185
logger.error(str(err))
163186
if self.error_callback:
164187
self.error_callback(str(err))
188+
self.algorithm_running = False
189+
190+
def apply_magnetic_form_factor_correction(
191+
self,
192+
ws_name: str,
193+
ion_name: str,
194+
output_ws_name: str,
195+
) -> None:
196+
"""Apply MagneticFormFactorCorrection to the workspace.
197+
198+
Parameters
199+
----------
200+
ws_name : str
201+
Workspace name
202+
ion_name : str
203+
Ion name
204+
output_ws_name : str
205+
Output workspace name
206+
207+
Returns
208+
-------
209+
None
210+
"""
211+
212+
alg_obs = MagneticFormFactorCorrectionMDObserver(
213+
parent=self,
214+
ws_name=ws_name,
215+
)
216+
self.algorithms_observers.add(alg_obs)
217+
218+
self.algorithm_running = True
219+
220+
logger.information(f"Applying Magnetic Form Factor Correction with ion name {ion_name}")
221+
222+
alg = AlgorithmManager.create("MagneticFormFactorCorrectionMD")
223+
alg_obs.observeFinish(alg)
224+
alg_obs.observeError(alg)
225+
226+
alg.initialize()
227+
alg.setLogging(False)
228+
try:
229+
alg.setProperty("InputWorkspace", ws_name)
230+
alg.setProperty("IonName", ion_name)
231+
alg.setProperty("OutputWorkspace", output_ws_name)
232+
alg.execute()
233+
except RuntimeError as err:
234+
logger.error(str(err))
235+
if self.error_callback:
236+
self.error_callback(str(err))
237+
self.algorithm_running = False
165238

166239
def connect_error_message(self, callback):
167240
"""Set the callback function for error messages.
@@ -208,6 +281,7 @@ def apply_detailed_balance_finished(
208281
else:
209282
logger.information(f"Finished ApplyDetailedBalanceMD for {ws_name}")
210283
self.algorithms_observers.remove(alg)
284+
self.algorithm_running = False
211285

212286
def apply_scattered_transmission_correction_finished(
213287
self,
@@ -240,6 +314,40 @@ def apply_scattered_transmission_correction_finished(
240314
else:
241315
logger.information(f"Finished DgsScatteredTransmissionCorrectionMD for {ws_name}")
242316
self.algorithms_observers.remove(alg)
317+
self.algorithm_running = False
318+
319+
def apply_magnetic_form_factor_correction_finished(
320+
self,
321+
ws_name: str,
322+
alg: "MagneticFormFactorCorrectionMDObserver",
323+
error: bool = False,
324+
msg="",
325+
) -> None:
326+
"""Call when MagneticFormFactorCorrectionMD finishes.
327+
328+
Parameters
329+
----------
330+
ws_name : str
331+
Workspace name
332+
alg : MagneticFormFactorCorrectionMDObserver
333+
Observer
334+
error : bool, optional
335+
Error flag, by default False
336+
msg : str, optional
337+
Error message, by default ""
338+
339+
Returns
340+
-------
341+
None
342+
"""
343+
if error:
344+
logger.error(f"Error in MagneticFormFactorCorrectionMD for {ws_name}")
345+
if self.error_callback:
346+
self.error_callback(msg)
347+
else:
348+
logger.information(f"Finished MagneticFormFactorCorrectionMD for {ws_name}")
349+
self.algorithms_observers.remove(alg)
350+
self.algorithm_running = False
243351

244352
def get_ws_alg_histories(self, ws_name: str) -> list:
245353
"""Get algorithm histories of the workspace.
@@ -297,6 +405,26 @@ def has_scattered_transmission_correction(self, ws_name: str) -> bool:
297405
return True
298406
return False
299407

408+
def has_magnetic_form_factor_correction(self, ws_name: str) -> Tuple[bool, str]:
409+
"""Check if the workspace has MagneticFormFactorCorrectionMD applied.
410+
411+
Parameters
412+
----------
413+
ws_name : str
414+
Workspace name
415+
416+
Returns
417+
-------
418+
Tuple[bool, str]
419+
True if the workspace has MagneticFormFactorCorrectionMD applied.
420+
Ion name if the workspace has MagneticFormFactorCorrectionMD applied.
421+
"""
422+
alg_histories = self.get_ws_alg_histories(ws_name)
423+
for alg_history in alg_histories:
424+
if alg_history.name() == "MagneticFormFactorCorrectionMD":
425+
return True, alg_history.getPropertyValue("IonName")
426+
return False, ""
427+
300428

301429
class ApplyDetailedBalanceMDObserver(AlgorithmObserver):
302430
"""Observer for ApplyDetailedBalanceMD algorithm"""
@@ -334,3 +462,31 @@ def errorHandle(self, msg): # pylint: disable=invalid-name
334462
self.parent.apply_scattered_transmission_correction_finished(
335463
ws_name=self.ws_name, alg=self, error=True, msg=msg
336464
)
465+
466+
467+
class MagneticFormFactorCorrectionMDObserver(AlgorithmObserver):
468+
"""Observer for MagneticFormFactorCorrectionMD algorithm"""
469+
470+
def __init__(self, parent, ws_name: str) -> None:
471+
super().__init__()
472+
self.parent = parent
473+
self.ws_name = ws_name
474+
475+
def finishHandle(self): # pylint: disable=invalid-name
476+
"""Call upon algorithm finishing"""
477+
self.parent.apply_magnetic_form_factor_correction_finished(ws_name=self.ws_name, alg=self, error=False, msg="")
478+
479+
def errorHandle(self, msg): # pylint: disable=invalid-name
480+
"""Call upon algorithm error"""
481+
self.parent.apply_magnetic_form_factor_correction_finished(ws_name=self.ws_name, alg=self, error=True, msg=msg)
482+
483+
484+
def get_ions_list():
485+
"""Get the list of allowed ions from the MagneticFormFactorCorrectionMD algorithm"""
486+
try:
487+
alg = AlgorithmManager.create("MagneticFormFactorCorrectionMD")
488+
except RuntimeError as err:
489+
logger.error(str(err))
490+
return []
491+
alg.initialize()
492+
return sorted(alg.getProperty("IonName").allowedValues)

Diff for: src/shiver/presenters/histogram.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from qtpy.QtWidgets import QWidget
55
from shiver.views.corrections import Corrections
6-
from shiver.models.corrections import CorrectionsModel
6+
from shiver.models.corrections import CorrectionsModel, get_ions_list
77
from shiver.models.generate import gather_mde_config_dict, save_mde_config_dict
88

99
from shiver.models.generate import GenerateModel
@@ -270,9 +270,19 @@ def create_corrections_tab(self, name):
270270
# create a new model
271271
corrections_tab_model = CorrectionsModel()
272272

273+
# populate valid ions
274+
ions = get_ions_list()
275+
if ions:
276+
corrections_tab_view.ion_name.addItems(ions)
277+
else:
278+
# disable the magnetic structure factor correction if MagneticFormFactorCorrectionMD is not available
279+
corrections_tab_view.magnetic_structure_factor.setEnabled(False)
280+
corrections_tab_view.ion_name.setEnabled(False)
281+
273282
# populate initial values
274283
has_detailed_balance, temperature = corrections_tab_model.has_apply_detailed_balance(name)
275284
has_scattered_transmission_correction = corrections_tab_model.has_scattered_transmission_correction(name)
285+
has_magnetic_form_factor, ion_name = corrections_tab_model.has_magnetic_form_factor_correction(name)
276286
if has_detailed_balance:
277287
corrections_tab_view.detailed_balance.setChecked(True)
278288
corrections_tab_view.temperature.setText(str(temperature))
@@ -281,6 +291,13 @@ def create_corrections_tab(self, name):
281291
if has_scattered_transmission_correction:
282292
corrections_tab_view.hyspec_polarizer_transmission.setChecked(True)
283293
corrections_tab_view.hyspec_polarizer_transmission.setEnabled(False)
294+
if has_magnetic_form_factor:
295+
corrections_tab_view.magnetic_structure_factor.setChecked(True)
296+
idx = corrections_tab_view.ion_name.findText(ion_name)
297+
if idx != -1:
298+
corrections_tab_view.ion_name.setCurrentIndex(idx)
299+
corrections_tab_view.magnetic_structure_factor.setEnabled(False)
300+
corrections_tab_view.ion_name.setEnabled(False)
284301

285302
# inline functions
286303
def _apply():
@@ -298,6 +315,8 @@ def _apply():
298315
do_detail_balance,
299316
do_polarizer_transmission,
300317
corrections_tab_view.temperature.text(),
318+
corrections_tab_view.magnetic_structure_factor.isChecked(),
319+
corrections_tab_view.ion_name.currentText(),
301320
)
302321
# clean up
303322
tab_widget.setCurrentWidget(self._view)

Diff for: src/shiver/views/corrections.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
QLineEdit,
1111
QSpacerItem,
1212
QSizePolicy,
13+
QComboBox,
1314
)
1415
from qtpy.QtCore import Qt
1516
from .invalid_styles import INVALID_QLINEEDIT, INVALID_QCHECKBOX
@@ -52,8 +53,15 @@ def __init__(self, parent=None, name=None):
5253
self.debye_waller_correction.setEnabled(False)
5354

5455
# magentic structure factor (disabled for now)
55-
self.magentic_structure_factor = QCheckBox("Magentic structure factor")
56-
self.magentic_structure_factor.setEnabled(False)
56+
self.magnetic_structure_factor = QCheckBox("Magnetic structure factor")
57+
self.magnetic_structure_factor.setToolTip(
58+
"Correct for the magnetic structure factor.\nSee MagneticFormFactorCorrectionMD algorithm."
59+
)
60+
self.ion_name = QComboBox()
61+
magnetic_structure_layout = QHBoxLayout()
62+
magnetic_structure_layout.addWidget(self.magnetic_structure_factor)
63+
magnetic_structure_layout.addWidget(self.ion_name)
64+
magnetic_structure_layout.addStretch()
5765

5866
# action group
5967
# add a apply button
@@ -81,7 +89,7 @@ def __init__(self, parent=None, name=None):
8189
correction_layout.addLayout(detailed_balance_layout)
8290
correction_layout.addWidget(self.hyspec_polarizer_transmission)
8391
correction_layout.addWidget(self.debye_waller_correction)
84-
correction_layout.addWidget(self.magentic_structure_factor)
92+
correction_layout.addLayout(magnetic_structure_layout)
8593
correction_layout.addSpacerItem(QSpacerItem(0, 0, QSizePolicy.Minimum, QSizePolicy.Expanding))
8694
correction_layout.addLayout(action_layout)
8795

Diff for: src/shiver/views/oncat.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -422,12 +422,12 @@ def get_dataset_info( # pylint: disable=too-many-branches
422422
for idx, datafile in enumerate(datafiles):
423423
run_number[idx] = datafile["indexed"]["run_number"]
424424
angle[str(run_number[idx])] = (
425-
datafile["metadata"]["entry"].get("daslogs", {}).get(angle_pv, {}).get("average_value", np.NaN)
425+
datafile["metadata"]["entry"].get("daslogs", {}).get(angle_pv, {}).get("average_value", np.nan)
426426
)
427427
if use_notes:
428428
sid = datafile["metadata"]["entry"].get("notes", None)
429429
else:
430-
sid = datafile["metadata"]["entry"].get("daslogs", {}).get("sequencename", {}).get("value", np.NaN)
430+
sid = datafile["metadata"]["entry"].get("daslogs", {}).get("sequencename", {}).get("value", np.nan)
431431
if isinstance(sid, list):
432432
sid = sid[-1]
433433
sequence[idx] = sid

0 commit comments

Comments
 (0)