Skip to content

Commit

Permalink
Move PeakFitFactory into correct subpackage refs #219
Browse files Browse the repository at this point in the history
  • Loading branch information
peterfpeterson committed Dec 19, 2019
1 parent 618e682 commit 3bd967c
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 32 deletions.
16 changes: 5 additions & 11 deletions pyrs/core/pyrscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@
from pyrs.utilities import checkdatatypes
from pyrs.core import instrument_geometry
from pyrs.utilities import file_util
from pyrs.core import peak_fit_factory
from pyrs.peaks import PeakFitEngineFactory, SupportedPeakProfiles, SupportedBackgroundTypes
from pyrs.utilities.rs_project_file import HidraConstants, HidraProjectFile, HidraProjectFileMode
from pyrs.core import strain_stress_calculator
from pyrs.core import reduction_manager
from pyrs.core import polefigurecalculator
import os
import numpy

# Define Constants
SUPPORTED_PEAK_TYPES = ['PseudoVoigt', 'Gaussian', 'Voigt'] # 'Lorentzian': No a profile of HB2B


class PyRsCore(object):
"""
Expand Down Expand Up @@ -108,8 +105,7 @@ def init_peak_fit_engine(self, fit_tag):
# get workspace
workspace = self.reduction_service.get_hidra_workspace(fit_tag)
# create a controller from factory
self._peak_fitting_dict[fit_tag] = peak_fit_factory.PeakFitEngineFactory.getInstance('Mantid')(workspace,
None)
self._peak_fitting_dict[fit_tag] = PeakFitEngineFactory.getInstance('Mantid')(workspace, None)
# set wave length: TODO - #81+ - shall be a way to use calibrated or non-calibrated
wave_length_dict = workspace.get_wavelength(calibrated=False, throw_if_not_set=False)
if wave_length_dict is not None:
Expand Down Expand Up @@ -154,10 +150,8 @@ def fit_peaks(self, project_name="",

# Check Inputs
checkdatatypes.check_dict('Peak fitting (information) parameters', peaks_fitting_setup)
checkdatatypes.check_string_variable('Peak type', peak_type,
peak_fit_factory.SupportedPeakProfiles)
checkdatatypes.check_string_variable('Background type', background_type,
peak_fit_factory.SupportedBackgroundTypes)
checkdatatypes.check_string_variable('Peak type', peak_type, SupportedPeakProfiles)
checkdatatypes.check_string_variable('Background type', background_type, SupportedBackgroundTypes)

# Deal with sub runs
if sub_run_list is None:
Expand Down Expand Up @@ -647,4 +641,4 @@ def supported_peak_types(self):
list of supported peaks' types for fitting
:return:
"""
return SUPPORTED_PEAK_TYPES[:]
return SupportedPeakProfiles[:]
6 changes: 5 additions & 1 deletion pyrs/peaks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# flake8: noqa
from __future__ import (absolute_import, division, print_function) # python3 compatibility

from .peak_collection import PeakCollection
from .peak_collection import *
from .peak_fit_factory import *

__all__ = peak_collection.__all__ + peak_fit_factory.__all__
13 changes: 6 additions & 7 deletions pyrs/peaks/mantid_fit_peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from mantid.api import AnalysisDataService
from mantid.simpleapi import CreateWorkspace, FitPeaks

__all__ = ['MantidPeakFitEngine']

DEBUG = False # Flag for debugging mode

Expand Down Expand Up @@ -35,8 +36,6 @@ def __init__(self, workspace, mask_name):
self._fitted_function_error_table = None # fitted function parameters' fitting error table workspace
self._model_matrix_ws = None # MatrixWorkspace of the model from fitted function parameters

return

def _create_peak_center_ws(self, peak_center):
""" Create peak center workspace
:param peak_center: float or numpy array
Expand Down Expand Up @@ -117,30 +116,30 @@ def _set_default_peak_params_value(self, peak_function_name, peak_range):
# Estimate
estimated_heights, flat_bkgds = self.estimate_peak_height(peak_range)
max_estimated_height = estimated_heights.max()
flat_bkgd = flat_bkgds[np.argmax(estimated_heights)]
# do not pass A0 to FitPeaks

# Make the difference between peak profiles
if peak_function_name == 'Gaussian':
# Gaussian
peak_param_names = '{}, {}'.format('Height', 'Sigma', 'A0')
peak_param_names = '{}, {}'.format('Height', 'Sigma')

# sigma
instrument_sigma = Gaussian.cal_sigma(hidra_fwhm)

# set value
peak_param_values = "{}, {}".format(max_estimated_height, instrument_sigma, flat_bkgd)
peak_param_values = "{}, {}".format(max_estimated_height, instrument_sigma)

elif peak_function_name == 'PseudoVoigt':
# Pseudo-voig
default_mixing = 0.6

peak_param_names = '{}, {}, {}'.format('Mixing', 'Intensity', 'FWHM', 'A0')
peak_param_names = '{}, {}, {}'.format('Mixing', 'Intensity', 'FWHM')

# intensity
max_intensity = PseudoVoigt.cal_intensity(max_estimated_height, hidra_fwhm, default_mixing)

# set values
peak_param_values = "{}, {}, {}".format(default_mixing, max_intensity, hidra_fwhm, flat_bkgds)
peak_param_values = "{}, {}, {}".format(default_mixing, max_intensity, hidra_fwhm)

else:
# Non-supported case
Expand Down
2 changes: 2 additions & 0 deletions pyrs/peaks/peak_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from pyrs.utilities import checkdatatypes
from pyrs.core.peak_profile_utility import get_effective_parameters_converter, PeakShape, BackgroundFunction

__all__ = ['PeakCollection']


class PeakCollection(object):
"""
Expand Down
2 changes: 2 additions & 0 deletions pyrs/peaks/peak_fit_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from pyrs.core.peak_profile_utility import PeakShape
from pyrs.utilities import checkdatatypes

__all__ = ['PeakFitEngine']


class PeakFitEngine(object):
"""
Expand Down
18 changes: 9 additions & 9 deletions pyrs/core/peak_fit_factory.py → pyrs/peaks/peak_fit_factory.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
# Peak fitting engine
from pyrs.peaks.mantid_fit_peak import MantidPeakFitEngine
from pyrs.utilities import checkdatatypes


SupportedPeakProfiles = ['Gaussian', 'PseudoVoigt', 'Voigt']
SupportedBackgroundTypes = ['Flat', 'Linear', 'Quadratic']

__all__ = ['PeakFitEngineFactory', 'SupportedPeakProfiles', 'SupportedBackgroundTypes']


class PeakFitEngineFactory(object):
"""
Peak fitting engine factory
"""
@staticmethod
def getInstance(engine_name):
def getInstance(name):
""" Get instance of Peak fitting engine
:param engine_name:
:return:
"""
checkdatatypes.check_string_variable('Peak fitting engine', engine_name, ['Mantid', 'PyRS'])
checkdatatypes.check_string_variable('Peak fitting engine', name, ['Mantid', 'PyRS'])

# this must be here for now to stop circular imports
from .mantid_fit_peak import MantidPeakFitEngine

if engine_name == 'Mantid':
engine_class = MantidPeakFitEngine
if name == 'Mantid':
return MantidPeakFitEngine
else:
raise RuntimeError('Implement general scipy peak fitting engine')

return engine_class
4 changes: 4 additions & 0 deletions pyrs/peaks/scipypeakfitengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
from pyrs.utilities import checkdatatypes

__all__ = ['ScipyPeakFitEngine']


class ScipyPeakFitEngine(PeakFitEngine):
"""peak fitting engine class for mantid
Expand Down Expand Up @@ -72,6 +74,7 @@ def calculate_peak(X, Data, TTH, peak_function_name, background_function_name, R
else:
return Data - model_y

# TODO signature doesn't match base class
def fit_peaks(self, peak_function_name, background_function_name, scan_index=None):
"""
fit peaks
Expand Down Expand Up @@ -149,6 +152,7 @@ def fit_peaks(self, peak_function_name, background_function_name, scan_index=Non

return

# TODO arguments don't match base class
def calculate_fitted_peaks(self, scan_index):
"""
get the calculated peak's value
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/test_peak_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pyrs.core.summary_generator import SummaryGenerator
from pyrs.dataobjects import SampleLogs
from pyrs.utilities.rs_project_file import HidraProjectFile
from pyrs.core import peak_fit_factory
from pyrs.peaks import PeakFitEngineFactory
import h5py
from pyrs.core import peak_profile_utility
from matplotlib import pyplot as plt
Expand Down Expand Up @@ -225,7 +225,7 @@ def test_retrieve_fit_metadata(source_project_file, output_project_file, peak_ty

# Set peak fitting engine
# create a controller from factory
fit_engine = peak_fit_factory.PeakFitEngineFactory.getInstance('Mantid')(hd_ws, None)
fit_engine = PeakFitEngineFactory.getInstance('Mantid')(hd_ws, None)

# Fit peak
fit_engine.fit_multiple_peaks(sub_run_range=(None, None), # default is all sub runs
Expand Down Expand Up @@ -348,7 +348,7 @@ def test_improve_quality():

# Set peak fitting engine
# create a controller from factory
fit_engine = peak_fit_factory.PeakFitEngineFactory.getInstance('Mantid')(hd_ws, None)
fit_engine = PeakFitEngineFactory.getInstance('Mantid')(hd_ws, None)

peak_type = 'Gaussian'

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_peak_fit_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from pyrs.core.peak_fit_factory import PeakFitEngineFactory
from pyrs.peaks import PeakFitEngineFactory
from pyrs.core.workspaces import HidraWorkspace
from pyrs.core.peak_profile_utility import pseudo_voigt, PeakShape, BackgroundFunction
from pyrs.core.peak_profile_utility import Gaussian, PseudoVoigt
Expand Down

0 comments on commit 3bd967c

Please sign in to comment.