diff --git a/src/aiida_quantumespresso/workflows/pdos.py b/src/aiida_quantumespresso/workflows/pdos.py index 09ec45012..2bf6b62b9 100644 --- a/src/aiida_quantumespresso/workflows/pdos.py +++ b/src/aiida_quantumespresso/workflows/pdos.py @@ -10,10 +10,10 @@ Additional functionality: -- Setting ``'align_to_fermi': True`` in the inputs will ensure that the energy range is centred around the Fermi - energy when `Emin` and `Emax` are provided for both the `dos` and `projwfc` inputs. This is useful when you are only - interested in a certain energy range around the Fermi energy. By default the energy range is extracted from the - NSCF calculation. +- Setting ``'energy_range_vs_fermi'`` in the inputs allows to specify an energy range around the Fermi level that should + be covered by the DOS and PDOS. This is useful when you are only interested in a certain energy range around the + Fermi energy. By default, this is not specified and the energy range given in the `dos.x` and `projwfc.x` + inputs will be used. Storage memory management: @@ -97,7 +97,6 @@ def validate_inputs(value, _): - Check that either the `scf` or `nscf.pw.parent_folder` inputs is provided. - Check that the `Emin`, `Emax` and `DeltaE` inputs are the same for the `dos` and `projwfc` namespaces. - - Check that `Emin` and `Emax` are provided in case `align_to_fermi` is set to `True`. """ # Check that either the `scf` input or `nscf.pw.parent_folder` is provided. import warnings @@ -113,10 +112,13 @@ def validate_inputs(value, _): if value['dos']['parameters']['DOS'].get(par, None) != value['projwfc']['parameters']['PROJWFC'].get(par, None): return f'The `{par}`` parameter has to be equal for the `dos` and `projwfc` inputs.' - if value.get('align_to_fermi', False): + if value.get('energy_range_vs_fermi', False): for par in ['Emin', 'Emax']: - if value['dos']['parameters']['DOS'].get(par, None) is None: - return f'The `{par}`` parameter must be set in case `align_to_fermi` is set to `True`.' + if value['dos']['parameters']['DOS'].get(par, None): + warnings.warn( + f'The `{par}` parameter and `energy_range_vs_fermi` were specified.' + 'The value in `energy_range_vs_fermi` will be used.' + ) if 'nbands_factor' in value and 'nbnd' in value['nscf']['pw']['parameters'].base.attributes.get('SYSTEM', {}): return PdosWorkChain.exit_codes.ERROR_INVALID_INPUT_NUMBER_OF_BANDS.message @@ -160,6 +162,17 @@ def validate_projwfc(value, _): jsonschema.validate(value['parameters'].get_dict()['PROJWFC'], get_parameter_schema()) +def validate_energy_range_vs_fermi(value, _): + """Validate specified energy_range_vs_fermi. + + - List needs to consist of two float values. + """ + if len(value) != 2: + return f'`energy_range_vs_fermi` should be a `List` of length two, but got: {value}' + if not all(isinstance(val, (float, int)) for val in value): + return f'`energy_range_vs_fermi` should be a `List` of floats, but got: {value}' + + def clean_calcjob_remote(node): """Clean the remote directory of a ``CalcJobNode``.""" cleaned = False @@ -220,14 +233,15 @@ def define(cls, spec): help='Terminate workchain steps before submitting calculations (test purposes only).' ) spec.input( - 'align_to_fermi', - valid_type=orm.Bool, + 'energy_range_vs_fermi', + valid_type=orm.List, + required=False, serializer=to_aiida_type, - default=lambda: orm.Bool(False), + validator=validate_energy_range_vs_fermi, help=( - 'If true, Emin=>Emin-Efermi & Emax=>Emax-Efermi, where Efermi is taken from the `nscf` calculation. ' - 'Note that it only makes sense to align `Emax` and `Emin` to the fermi level in case they are actually ' - 'provided by in the `dos` and `projwfc` inputs, since otherwise the ' + 'Energy range with respect to the Fermi level that should be covered in DOS and PROJWFC calculation.' + 'If not specified but Emin and Emax are specified in the input parameters, these values will be used.' + 'Otherwise, the default values are extracted from the NSCF calculation.' ) ) spec.input('nbands_factor', valid_type=orm.Float, required=False, @@ -490,10 +504,14 @@ def _generate_dos_inputs(self): dos_inputs = AttributeDict(self.exposed_inputs(DosCalculation, 'dos')) dos_inputs.parent_folder = self.ctx.nscf_parent_folder dos_parameters = self.inputs.dos.parameters.get_dict() + energy_range_vs_fermi = self.inputs.get('energy_range_vs_fermi') - if dos_parameters.pop('align_to_fermi', False): - dos_parameters['DOS']['Emin'] = dos_parameters['Emin'] + self.ctx.nscf_fermi - dos_parameters['DOS']['Emax'] = dos_parameters['Emax'] + self.ctx.nscf_fermi + if energy_range_vs_fermi: + dos_parameters['DOS']['Emin'] = energy_range_vs_fermi[0] + self.ctx.nscf_fermi + dos_parameters['DOS']['Emax'] = energy_range_vs_fermi[1] + self.ctx.nscf_fermi + else: + dos_parameters['DOS'].setdefault('Emin', self.ctx.nscf_emin) + dos_parameters['DOS'].setdefault('Emax', self.ctx.nscf_emax) dos_inputs.parameters = orm.Dict(dos_parameters) dos_inputs['metadata']['call_link_label'] = 'dos' @@ -504,10 +522,14 @@ def _generate_projwfc_inputs(self): projwfc_inputs = AttributeDict(self.exposed_inputs(ProjwfcCalculation, 'projwfc')) projwfc_inputs.parent_folder = self.ctx.nscf_parent_folder projwfc_parameters = self.inputs.projwfc.parameters.get_dict() + energy_range_vs_fermi = self.inputs.get('energy_range_vs_fermi') - if projwfc_parameters.pop('align_to_fermi', False): - projwfc_parameters['PROJWFC']['Emin'] = projwfc_parameters['Emin'] + self.ctx.nscf_fermi - projwfc_parameters['PROJWFC']['Emax'] = projwfc_parameters['Emax'] + self.ctx.nscf_fermi + if energy_range_vs_fermi: + projwfc_parameters['PROJWFC']['Emin'] = energy_range_vs_fermi[0] + self.ctx.nscf_fermi + projwfc_parameters['PROJWFC']['Emax'] = energy_range_vs_fermi[1] + self.ctx.nscf_fermi + else: + projwfc_parameters['PROJWFC'].setdefault('Emin', self.ctx.nscf_emin) + projwfc_parameters['PROJWFC'].setdefault('Emax', self.ctx.nscf_emax) projwfc_inputs.parameters = orm.Dict(projwfc_parameters) projwfc_inputs['metadata']['call_link_label'] = 'projwfc' diff --git a/tests/conftest.py b/tests/conftest.py index 166b3173c..de50bb493 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -879,8 +879,8 @@ def _generate_workchain_ph(exit_code=None, inputs=None, return_inputs=False): def generate_workchain_pdos(generate_workchain, generate_inputs_pw, fixture_code): """Generate an instance of a `PdosWorkChain`.""" - def _generate_workchain_pdos(): - from aiida.orm import Bool, Dict + def _generate_workchain_pdos(emin=None, emax=None, energy_range_vs_fermi=None): + from aiida.orm import Bool, Dict, List from aiida_quantumespresso.utils.resources import get_default_options @@ -902,12 +902,15 @@ def _generate_workchain_pdos(): dos_params = { 'DOS': { - 'Emin': -10, - 'Emax': 10, 'DeltaE': 0.01, } } - projwfc_params = {'PROJWFC': {'Emin': -10, 'Emax': 10, 'DeltaE': 0.01, 'ngauss': 0, 'degauss': 0.01}} + projwfc_params = {'PROJWFC': {'DeltaE': 0.01, 'ngauss': 0, 'degauss': 0.01}} + + if emin and emax: + dos_params['DOS'].update({'Emin': emin, 'Emax': emax}) + projwfc_params['PROJWFC'].update({'Emin': emin, 'Emax': emax}) + dos = { 'code': fixture_code('quantumespresso.dos'), 'parameters': Dict(dos_params), @@ -928,9 +931,10 @@ def _generate_workchain_pdos(): 'nscf': nscf, 'dos': dos, 'projwfc': projwfc, - 'align_to_fermi': Bool(True), 'dry_run': Bool(True) } + if energy_range_vs_fermi: + inputs.update({'energy_range_vs_fermi': List(energy_range_vs_fermi)}) return generate_workchain(entry_point, inputs) diff --git a/tests/workflows/test_pdos.py b/tests/workflows/test_pdos.py index b91d1f518..19b70e62e 100644 --- a/tests/workflows/test_pdos.py +++ b/tests/workflows/test_pdos.py @@ -7,6 +7,7 @@ from aiida.common import LinkType from aiida.engine.utils import instantiate_process from aiida.manage.manager import get_manager +import numpy as np from plumpy import ProcessState import pytest @@ -27,31 +28,38 @@ def instantiate_process_cls(process_cls, inputs): return instantiate_process(runner, process_cls, **inputs) +def check_pdos_energy_range(dos_inputs, projwfc_inputs, expected_p_dos_inputs): + """Check the energy range of the pdos calculation.""" + # check generated inputs + dos_params = dos_inputs.parameters.get_dict() + projwfc_params = projwfc_inputs.parameters.get_dict() + + assert np.isclose(dos_params['DOS']['Emin'], expected_p_dos_inputs[0]) + assert np.isclose(dos_params['DOS']['Emax'], expected_p_dos_inputs[1]) + assert np.isclose(projwfc_params['PROJWFC']['Emin'], expected_p_dos_inputs[0]) + assert np.isclose(projwfc_params['PROJWFC']['Emax'], expected_p_dos_inputs[1]) + + @pytest.mark.parametrize( - 'nscf_output_parameters', [ - { - 'fermi_energy': 6.9 - }, - { - 'fermi_energy_down': 5.9, - 'fermi_energy_up': 6.9 - }, - ] + 'nscf_output_parameters,energy_range_inputs,expected_p_dos_inputs', + [({ + 'fermi_energy': 6.9 + }, (-10, 10, None), (-10, 10)), + ({ + 'fermi_energy_down': 5.9, + 'fermi_energy_up': 6.9 + }, (None, None, [-10, 10]), (-3.1, 16.9)), ({ + 'fermi_energy': 6.9 + }, (None, None, None), (-5.64024889, 8.91047649))] ) def test_default( - generate_workchain_pdos, - generate_workchain_pw, - fixture_localhost, - generate_remote_data, - generate_calc_job, - generate_calc_job_node, - fixture_sandbox, - generate_bands_data, - nscf_output_parameters, + generate_workchain_pdos, generate_workchain_pw, fixture_localhost, generate_remote_data, generate_calc_job, + generate_calc_job_node, fixture_sandbox, generate_bands_data, nscf_output_parameters, energy_range_inputs, + expected_p_dos_inputs ): """Test instantiating the WorkChain, then mock its process, by calling methods in the ``spec.outline``.""" - wkchain = generate_workchain_pdos() + wkchain = generate_workchain_pdos(*energy_range_inputs) assert wkchain.setup() is None assert wkchain.serial_clean() is False @@ -93,8 +101,7 @@ def test_default( result.store() result.base.links.add_incoming(mock_wknode, link_type=LinkType.RETURN, link_label='output_parameters') - bands_data = generate_bands_data() - bands_data.store() + bands_data = generate_bands_data().store() bands_data.base.links.add_incoming(mock_wknode, link_type=LinkType.RETURN, link_label='output_band') wkchain.ctx.workchain_nscf = mock_wknode @@ -104,6 +111,9 @@ def test_default( # mock run dos and projwfc, and check that their inputs are acceptable dos_inputs, projwfc_inputs = wkchain.run_pdos_parallel() + + check_pdos_energy_range(dos_inputs, projwfc_inputs, expected_p_dos_inputs) + generate_calc_job(fixture_sandbox, 'quantumespresso.dos', dos_inputs) generate_calc_job(fixture_sandbox, 'quantumespresso.projwfc', projwfc_inputs)