diff --git a/src/aiida_koopmans/calculations/kcw.py b/src/aiida_koopmans/calculations/kcw.py index e7abd6f..fe834c9 100644 --- a/src/aiida_koopmans/calculations/kcw.py +++ b/src/aiida_koopmans/calculations/kcw.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- """`CalcJob` implementation for the kcw.x code of Quantum ESPRESSO.""" from pathlib import Path +import os from aiida import orm from aiida.plugins import DataFactory from aiida_quantumespresso.calculations.namelists import NamelistsCalculation -SingleFileData = DataFactory('core.singlefile') - class KcwCalculation(NamelistsCalculation): """`CalcJob` implementation for the kcw.x code of Quantum ESPRESSO. @@ -45,13 +44,13 @@ def define(cls, spec): spec.input('kpoints', valid_type=orm.KpointsData, help='kpoint path if do_bands=True in the parameters', required=False) #spec.input('wann_occ_hr', valid_type=SingleFileData, help='wann_occ_hr', required=False) #spec.input('wann_emp_hr', valid_type=SingleFileData, help='wann_emp_hr', required=False) - spec.input('alpha_occ', valid_type=SingleFileData, help='alpha_occ', required=False) - spec.input('alpha_emp', valid_type=SingleFileData, help='alpha_emp', required=False) - spec.input('wann_u_mat', valid_type=SingleFileData, help='wann_occ_u', required=False) - spec.input('wann_emp_u_mat', valid_type=SingleFileData, help='wann_emp_u', required=False) - spec.input('wann_emp_u_dis_mat', valid_type=SingleFileData, help='wann_dis_u', required=False) - spec.input('wann_centres_xyz', valid_type=SingleFileData, help='wann_occ_centres', required=False) - spec.input('wann_emp_centres_xyz', valid_type=SingleFileData, help='wann_emp_centres', required=False) + spec.input('alpha_occ', valid_type=(orm.SinglefileData, orm.RemoteData), help='alpha_occ', required=False) + spec.input('alpha_emp', valid_type=(orm.SinglefileData, orm.RemoteData), help='alpha_emp', required=False) + spec.input('wann_u_mat', valid_type=(orm.SinglefileData, orm.RemoteData), help='wann_occ_u', required=False) + spec.input('wann_emp_u_mat', valid_type=(orm.SinglefileData, orm.RemoteData), help='wann_emp_u', required=False) + spec.input('wann_emp_u_dis_mat', valid_type=(orm.SinglefileData, orm.RemoteData), help='wann_dis_u', required=False) + spec.input('wann_centres_xyz', valid_type=(orm.SinglefileData, orm.RemoteData), help='wann_occ_centres', required=False) + spec.input('wann_emp_centres_xyz', valid_type=(orm.SinglefileData, orm.RemoteData), help='wann_emp_centres', required=False) spec.input('settings', valid_type=orm.Dict, required=True, default=lambda: orm.Dict({ 'CMDLINE': ["-in", cls._DEFAULT_INPUT_FILE], }), help='Use an additional node for special settings',) #validator=validate_parameters,) @@ -78,12 +77,27 @@ def define(cls, spec): def prepare_for_submission(self, folder): calcinfo = super().prepare_for_submission(folder) - - for wann_file in ['wann_u_mat','wann_emp_u_mat','wann_emp_u_dis_mat','wann_centres_xyz','wann_emp_centres_xyz']: - if hasattr(self.inputs,wann_file): - wannier_singelfiledata = getattr(self.inputs, wann_file) - calcinfo.local_copy_list.append((wannier_singelfiledata.uuid, wannier_singelfiledata.filename, wann_file.replace("_mat",".mat").replace("_xyz",".xyz").replace("wann","aiida"))) - + + for wann_input in ['wann_u_mat','wann_emp_u_mat','wann_emp_u_dis_mat','wann_centres_xyz','wann_emp_centres_xyz']: + wann_parent = getattr(self.inputs, wann_input, None) + if isinstance(wann_parent, orm.SinglefileData): # local copy to be send to the remote + calcinfo.local_copy_list.append((wann_parent.uuid, wann_parent.filename, wann_input.replace("_mat",".mat").replace("_xyz",".xyz").replace("wann","aiida"))) + elif isinstance(wann_parent, orm.RemoteData): + # if remote, we symlink all the files + if wann_input == 'wann_u_mat': + for wann_file in ['wann_u_mat', 'wann_centres_xyz']: + calcinfo.remote_symlink_list.append( + create_symlink_tuple(parent_folder = wann_parent, + filename = wann_file.replace("_mat",".mat").replace("_xyz",".xyz").replace("wann","aiida"), + target = wann_file.replace("_mat",".mat").replace("_xyz",".xyz").replace("wann","aiida"))) + elif wann_input == 'wann_emp_u_mat': + for wann_file in ['wann_emp_u_mat', 'wann_emp_centres_xyz', 'wann_emp_u_dis_mat']: + calcinfo.remote_symlink_list.append( + create_symlink_tuple(parent_folder = wann_parent, + filename = wann_file.replace("_mat",".mat").replace("_xyz",".xyz").replace("wann","aiida").replace("_emp",""), + target = wann_file.replace("_mat",".mat").replace("_xyz",".xyz").replace("wann","aiida"))) + + # TODO: fix the alphas copy for alpha_file in ['alpha_occ','alpha_emp']: if hasattr(self.inputs,alpha_file): suffix = alpha_file.replace("alpha_occ","").replace("alpha_emp","_empty") @@ -97,7 +111,14 @@ def prepare_for_submission(self, folder): handle.write(kpoints_card) return calcinfo - + +def create_symlink_tuple(parent_folder: orm.RemoteData, filename: str, target: str): + return ( + parent_folder.computer.uuid, + os.path.join(parent_folder.get_remote_path(), + filename), target + ) + def prepare_kpoints_card(kpoints=None): # from the BasePwCpInputGenerator, I had to move it here as we cannot just inherit from aiida.common import exceptions diff --git a/src/aiida_koopmans/engine/aiida.py b/src/aiida_koopmans/engine/aiida.py index 60c2fdf..2ba4a70 100644 --- a/src/aiida_koopmans/engine/aiida.py +++ b/src/aiida_koopmans/engine/aiida.py @@ -133,6 +133,7 @@ def update_statuses(self) -> None: elif workchain.is_finished or workchain.is_excepted or workchain.is_killed: self._step_failed_message_by_uid(uid) self.set_status_by_uid(uid, Status.FAILED) + raise ValueError(f"Workchain {workchain.pk} failed.") return @@ -155,7 +156,9 @@ def load_results(self, step: Step) -> None: output = None if step.ext_out == ".wout": output = read_output_file(step, workchain.outputs.wannier90.retrieved) - elif step.ext_out in [".pwo",".kho"]: + if "remote_folder" in workchain.outputs.wannier90: + self.step_data['steps'][step.uid]['remote_folder'] = workchain.outputs.wannier90.remote_folder.pk + elif step.ext_out in [".pwo",".w2ko",".kso",".kho"]: output = read_output_file(step, workchain.outputs.retrieved) if hasattr(output.calc, 'kpts'): step.kpts = output.calc.kpts @@ -163,7 +166,7 @@ def load_results(self, step: Step) -> None: output = read_output_file(step, workchain.outputs.retrieved) - if step.ext_out in [".pwo",".pro",".wout",".kso",".kho"]: + if step.ext_out in [".pwo",".pro",".wout",".w2ko",".kso",".kho"]: step.calc = output.calc step.results = output.calc.results #if step.ext_out == ".pwo": step.generate_band_structure() #nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons'])) diff --git a/src/aiida_koopmans/utils.py b/src/aiida_koopmans/utils.py index 7406fa3..6f60ab1 100644 --- a/src/aiida_koopmans/utils.py +++ b/src/aiida_koopmans/utils.py @@ -288,15 +288,163 @@ def get_projwfc_builder_from_ase(projwfc_calculator, step_data=None): return builder, step_data +def get_kcw_builder_from_ase(kcw_calculator, step_data=None): + + from aiida import load_profile, orm + load_profile() + + aiida_inputs = step_data["configuration"] + + # here we should find the parent folder and the wann files, merged or not (single block for emp or occ manifold). + parent_folder = None + wann_u_mat = None + wann_emp_u_mat = None + wann_emp_u_dis_mat = None + wann_centres_xyz = None + wann_emp_centres_xyz = None + for step_uid, val in step_data['steps'].items(): + if "nscf" in step_uid: + nscf = orm.load_node(val["workchain"]) + parent_folder = nscf.outputs.remote_folder + if "kcw_wannier" in step_uid: + w2kc = orm.load_node(val["workchain"]) + parent_folder = w2kc.outputs.remote_folder + + # SinglefileData merged files: + if "merge_occ_wannier_u" in step_uid: + wann_u_mat = orm.load_node(val['wannier90_u.mat']) + if "merge_occ_wannier_centers" in step_uid: + wann_centres_xyz = orm.load_node(val['wannier90_centres.xyz']) + if "merge_emp_wannier_u" in step_uid: # TODO: check if this is correct + wann_emp_u_mat = orm.load_node(val['wannier90_u.mat']) + if "merge_emp_wannier_centers" in step_uid: + wann_emp_centres_xyz = orm.load_node(val['wannier90_centres.xyz']) + if "merge_emp_wannier_u_dis" in step_uid: + wann_emp_u_dis_mat = orm.load_node(val['wannier90_u_dis.mat']) + + + # RemoteData folders: this is when only one block in occ or emp manifold. + # Instead of the SinglefileData (as searched above), we have only the RemoteData + # of the wannnier90 calc. + # TODO: explain this logic. + tmp_wann_emp_u_mat = None + for step_uid, val in step_data['steps'].items(): + + # the first hit is the single block of occ manifold, + # so we assign it and then we never hit again this block. + if not wann_u_mat and "03-wannier90" in step_uid: + wann_u_mat = orm.load_node(val["remote_folder"]) + + # we continue updating it up to the last hit. + # the last hit is the single block of emp manifold + if not wann_emp_u_mat and "03-wannier90" in step_uid: + tmp_wann_emp_u_mat = orm.load_node(val["remote_folder"]) + + if tmp_wann_emp_u_mat: wann_emp_u_mat = tmp_wann_emp_u_mat + + # get the kcw calculator ext_out: we have three cases: w2ko, kso, kho + ext_out = kcw_calculator.ext_out + + control_namelist = kcw_inputs_keys[ext_out]['control'] + wannier_namelist = kcw_inputs_keys[ext_out]['wannier'] + + control_dict = { + k: v if k in control_namelist else None + for k, v in kcw_calculator.parameters.items() + if k not in ALL_BLOCKED_KEYWORDS + } + + control_dict["calculation"] = "wann2kcw" + for k in list(control_dict): + if control_dict[k] is None: + control_dict.pop(k) + + wannier_dict = { + k: v if k in wannier_namelist else None + for k, v in kcw_calculator.parameters.items() + # ? Using all here, as blocked Wannier90 keywords doesn't contain 'seedname', but kcw does + if k not in ALL_BLOCKED_KEYWORDS + } + for k in list(wannier_dict): + if wannier_dict[k] is None: + wannier_dict.pop(k) + + screening_dict = { + k: v if k in kcs_keys['screen'] else None + for k, v in kcw_calculator.parameters.items() + if k not in ALL_BLOCKED_KEYWORDS + } + for k in list(screening_dict): + if screening_dict[k] is None: + screening_dict.pop(k) + + ham_dict = { + k: v if k in kch_keys['ham'] else None + for k, v in kcw_calculator.parameters.items() + if k not in ALL_BLOCKED_KEYWORDS + } + for k in list(ham_dict): + if ham_dict[k] is None: + ham_dict.pop(k) + + kcw_params = { + "CONTROL": control_dict, + "WANNIER": wannier_dict, + } + if ext_out == ".kso": + kcw_params["SCREEN"] = screening_dict + kcw_params["CONTROL"]["calculation"] = "screen" + elif ext_out == ".kho": + kcw_params["CONTROL"]["calculation"] = "ham" + kcw_params["HAM"] = ham_dict + + # builder. + builder = KcwCalculation.get_builder() + builder.parameters = orm.Dict(kcw_params) + builder.code = orm.load_code(aiida_inputs["kcw_code"]) + + builder.metadata = aiida_inputs["metadata"] + if "metadata_kcw" in aiida_inputs: + builder.metadata = aiida_inputs["metadata_kcw"] + + if ext_out == ".kho": + breakpoint() + # I provide kpoints as an array (output in the wannierized band structure), so I need to convert them. + kpoints = orm.KpointsData() + kpoints.set_kpoints(kcw_calculator._parameters.kpts.kpts, cartesian=False) + builder.kpoints = kpoints + + builder.parent_folder = parent_folder + + if control_dict.get( + "read_unitary_matrix", False + ): + if wann_u_mat: builder.wann_u_mat = wann_u_mat + if wann_emp_u_mat: builder.wann_emp_u_mat = wann_emp_u_mat + if wann_emp_u_dis_mat: builder.wann_emp_u_dis_mat = wann_emp_u_dis_mat + if wann_centres_xyz: builder.wann_centres_xyz = wann_centres_xyz + if wann_emp_centres_xyz: builder.wann_emp_centres_xyz = wann_centres + + if hasattr(kcw_calculator, "alphas"): # TODO: add support for this. + builder.alpha_occ = kcw_calculator.alphas_files["alpha"] + builder.alpha_emp = kcw_calculator.alphas_files["alpha_empty"] + + return builder, step_data ## Here we have the mapping for the calculators initialization. used in the `aiida_calculate_trigger`. mapping_calculators = { ".pwo" : get_PwBaseWorkChain_from_ase, ".wout": get_Wannier90BandsWorkChain_builder_from_ase, ".pro": get_projwfc_builder_from_ase, - #".w2ko": from_wann2kc_to_KcwCalculation, - #".kso": from_kcwscreen_to_KcwCalculation, - #".kho": from_kcwham_to_KcwCalculation, + ".w2ko": get_kcw_builder_from_ase, + ".kso": get_kcw_builder_from_ase, + ".kho": get_kcw_builder_from_ase, +} + +kcw_inputs_keys = { + ".w2ko": w2kcw_keys, + ".kso": kcs_keys, + ".kho": kch_keys, } # read the output file, mimicking the read_results method of ase-koopmans: https://github.com/elinscott/ase_koopmans/blob/master/ase/calculators/espresso/_espresso.py