diff --git a/src/aiida_koopmans/engine/aiida.py b/src/aiida_koopmans/engine/aiida.py index 7c6def4..030a12a 100644 --- a/src/aiida_koopmans/engine/aiida.py +++ b/src/aiida_koopmans/engine/aiida.py @@ -151,7 +151,7 @@ def get_pseudopotential(self, library: str, element: str): qb.append(orm.Group, filters={'label': {'==': library}}, tag='pseudo_group') qb.append(UpfData, filters={'attributes.element': {'==': element}}, with_group='pseudo_group') - print(qb.all()) + pseudo_data = None for pseudo in qb.all(): with tempfile.TemporaryDirectory() as dirpath: temp_file = pathlib.Path(dirpath) / (pseudo[0].attributes['element'] + '.upf') @@ -159,4 +159,9 @@ def get_pseudopotential(self, library: str, element: str): temp_file.write_bytes(handle.read()) pseudo_data = read_pseudo_file(temp_file) + if not pseudo_data: + raise ValueError(f"Could not find pseudopotential for element {element} in library {library}") + + self.step_data['pseudo_family'] = library + return pseudo_data diff --git a/src/aiida_koopmans/utils.py b/src/aiida_koopmans/utils.py index 2860ad0..c3a9e45 100644 --- a/src/aiida_koopmans/utils.py +++ b/src/aiida_koopmans/utils.py @@ -36,25 +36,6 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None): load_profile() - """ - We should check automatically on the accepted keywords in PwCalculation and where are. Should be possible. - we suppose that the calculator has an attribute called mode e.g. - - pw_calculator.parameters.mode = { - "pw_code": "pw-7.2-ok@localhost", - "metadata": { - "options": { - "max_wallclock_seconds": 3600, - "resources": { - "num_machines": 1, - "num_mpiprocs_per_machine": 1, - "num_cores_per_mpiproc": 1 - }, - "custom_scheduler_commands": "export OMP_NUM_THREADS=1" - } - } - } - """ aiida_inputs = step_data['configuration'] calc_params = pw_calculator._parameters @@ -81,12 +62,12 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None): for k in pw_keys['electrons']: if k in calc_params.keys() and k not in ALL_BLOCKED_KEYWORDS: pw_overrides["ELECTRONS"][k] = calc_params[k] - + builder = PwBaseWorkChain.get_builder_from_protocol( code=aiida_inputs["pw_code"], structure=structure, overrides={ - "pseudo_family": "PseudoDojo/0.4/LDA/SR/standard/upf", # TODO: automatic store of pseudos from koopmans folder, if not. + "pseudo_family": step_data["pseudo_family"], # TODO: automatic store of pseudos from koopmans folder, if not. "pw": {"parameters": pw_overrides}, }, electronic_type=ElectronicType.INSULATOR, @@ -152,7 +133,7 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None) builder = Wannier90BandsWorkChain.get_builder_from_protocol( codes=codes, structure=nscf.inputs.pw.structure, - pseudo_family="PseudoDojo/0.4/LDA/SR/standard/upf", + pseudo_family=step_data["pseudo_family"], protocol="moderate", projection_type=WannierProjectionType.ANALYTIC, print_summary=False,