Skip to content

Commit

Permalink
get_pseudpotential method
Browse files Browse the repository at this point in the history
  • Loading branch information
mikibonacci committed Dec 4, 2024
1 parent c906654 commit 35d793f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 20 deletions.
39 changes: 22 additions & 17 deletions src/aiida_koopmans/engine/aiida.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self, *args, **kwargs):
'configuration': kwargs.pop('configuration', None),
'steps': {}
}
self.skip_message = False

# here we add the logic to populate configuration by default
# 1. we look for codes stored in AiiDA at localhost, e.g. pw-version@localhost,
Expand All @@ -49,7 +50,6 @@ def run(self, step: Step):
self.get_status(step)
if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
self.set_status(step, Status.COMPLETED)
#self._step_completed_message(step)
return

self.step_data['steps'][step.uid] = {} # maybe not needed
Expand Down Expand Up @@ -77,7 +77,9 @@ def dump_step_data(self):
pickle.dump(self.step_data, f)

def get_status(self, step: Step) -> Status:
return self.get_status_by_uid(step.uid)
status = self.get_status_by_uid(step.uid)
#print(f"Getting status for step {step.uid}: {status}")
return status


def get_status_by_uid(self, uid: str) -> Status:
Expand All @@ -88,30 +90,38 @@ def get_status_by_uid(self, uid: str) -> Status:

def set_status(self, step: Step, status: Status):
self.set_status_by_uid(step.uid, status)
#print(f"Step {step.uid} is {status}")

def set_status_by_uid(self, uid: str, status: Status):
self.step_data['steps'][uid]['status'] = status
self.dump_step_data()

def update_statuses(self) -> None:
time.sleep(5)

time.sleep(1)
for uid in self.step_data['steps']:
# convert from AiiDA to ASE results and populate ASE calculator

if not self.get_status_by_uid(uid) == Status.RUNNING:
continue

workchain = orm.load_node(self.step_data['steps'][uid]['workchain'])
if workchain.is_finished_ok:
self._step_completed_message_by_uid(uid)
self.set_status_by_uid(uid, Status.COMPLETED)

elif workchain.is_finished or workchain.is_excepted or workchain.is_killed:
self._step_failed_message_by_uid(uid)
self.set_status_by_uid(uid, Status.FAILED)

return

def load_results(self, step: Step) -> None:

self.load_step_data()

if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
self.set_status(step, Status.COMPLETED)
return
workchain = orm.load_node(self.step_data['steps'][step.uid]['workchain'])
if "remote_folder" in workchain.outputs:
self.step_data['steps'][step.uid]['remote_folder'] = workchain.outputs.remote_folder.pk
Expand All @@ -127,31 +137,26 @@ def load_results(self, step: Step) -> None:
if step.ext_out in [".pwo",".wout",".kso",".kho"]:
step.calc = output.calc
step.results = output.calc.results
step.generate_band_structure() #nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons']))

self._step_completed_message(step)
if step.ext_out == ".pwo": step.generate_band_structure() #nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons']))

self.dump_step_data()


def load_old_calculator(self, calc: Calc):
raise NotImplementedError # load_old_calculator(calc)

def get_pseudo_data(self, workflow):
pseudo_data = {}
symbols_list = []
for symbol in workflow.pseudopotentials.keys():
symbols_list.append(symbol)
def get_pseudopotential(self, library: str, element: str):

qb = orm.QueryBuilder()
qb.append(orm.Group, filters={'label': {'==': 'pseudo_group'}}, tag='pseudo_group')
qb.append(UpfData, filters={'attributes.element': {'in': symbols_list}}, with_group='pseudo_group')
qb.append(orm.Group, filters={'label': {'==': library}}, tag='pseudo_group')
qb.append(UpfData, filters={'attributes.element': {'==': element}}, with_group='pseudo_group')

print(qb.all())
for pseudo in qb.all():
with tempfile.TemporaryDirectory() as dirpath:
temp_file = pathlib.Path(dirpath) / pseudo[0].attributes.element + '.upf'
with pseudo[0].open(pseudo[0].attributes.element + '.upf', 'wb') as handle:
temp_file = pathlib.Path(dirpath) / (pseudo[0].attributes['element'] + '.upf')
with pseudo[0].open(pseudo[0].attributes['element'] + '.upf', 'rb') as handle:
temp_file.write_bytes(handle.read())
pseudo_data[pseudo[0].attributes.element] = read_pseudo_file(temp_file)
pseudo_data = read_pseudo_file(temp_file)

return pseudo_data
8 changes: 5 additions & 3 deletions src/aiida_koopmans/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None)
codes=codes,
structure=nscf.inputs.pw.structure,
pseudo_family="PseudoDojo/0.4/LDA/SR/standard/upf",
protocol="fast",
protocol="moderate",
projection_type=WannierProjectionType.ANALYTIC,
print_summary=False,
)

# Use nscf explicit kpoints
kpoints = orm.KpointsData()
kpoints.set_cell_from_structure(builder.structure)
Expand All @@ -167,7 +167,6 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None)
# set kpath using the WannierizeWFL data.
k_coords = []
k_labels = []
print(w90_calculator.kpts)
k_path=w90_calculator.parameters.kpoint_path.kpts
special_k = w90_calculator.parameters.kpoint_path.todict()["special_points"]
k_linear,special_k_coords,special_k_labels = w90_calculator.parameters.kpoint_path.get_linear_kpoint_axis()
Expand All @@ -189,6 +188,9 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None)
del builder.nscf
del builder.projwfc

# pop dis_froz_max
params.pop('dis_froz_max',None)

for k,v in w90_calculator.parameters.items():
if k not in ["kpoints","kpoint_path","projections"]:
params[k] = v
Expand Down

0 comments on commit 35d793f

Please sign in to comment.