linting
svandenhaute committed Jul 28, 2024
1 parent f139aef commit 695d53e
Showing 13 changed files with 47 additions and 40 deletions.
7 changes: 4 additions & 3 deletions psiflow/data/utils.py
@@ -26,7 +26,7 @@ def _write_frames(
         all_states.append(extra_states)
     with open(outputs[0], "w") as f:
         for state in all_states:  # avoid double newline by using strip!
-            f.write(state.to_string().strip() + '\n')
+            f.write(state.to_string().strip() + "\n")


 write_frames = python_app(_write_frames, executors=["default_threads"])
@@ -605,8 +605,9 @@ def _batch_frames(
         else:  # write and clear
             pass
     if len(data) > 0:
-        with open(outputs[batch_index], "w") as g:
-            g.write("\n".join(data))
+        with open(outputs[batch_index], "w") as f:
+            f.write("\n".join([d.strip() for d in data if d is not None]))
+            f.write("\n")
         batch_index += 1
     assert batch_index == len(outputs)
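Both hunks in this file normalize newlines the same way when concatenating frame strings: each frame is strip()-ed and exactly one "\n" separates (and terminates) the frames, so no blank lines end up between XYZ blocks. A minimal standalone sketch of the idea (not psiflow code), assuming the frame strings are extxyz blocks that may or may not already end in a newline:

frames = [
    "2\n\nH 0.0 0.0 0.0\nH 0.0 0.0 0.7\n",  # ends with a trailing newline
    "1\n\nO 0.0 0.0 0.0",                   # no trailing newline
]

# strip() removes any trailing newline, so the join below never produces
# blank lines between frames; the final write adds the single newline
# that terminates the file.
with open("frames.xyz", "w") as f:
    f.write("\n".join([frame.strip() for frame in frames if frame is not None]))
    f.write("\n")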

39 changes: 21 additions & 18 deletions psiflow/execution.py
@@ -1,9 +1,8 @@
 from __future__ import annotations  # necessary for type-guarding class methods

-import re
-
 import logging
 import math
+import re
 import shutil
 import sys
 from pathlib import Path
@@ -97,7 +96,7 @@ def create_executor(self, path: Path, **kwargs) -> ParslExecutor:
             worker_options.append("--gpus={}".format(self.max_workers))

         # ensure proper scale in
-        if getattr(self.parsl_provider, 'nodes_per_block', 1) > 1:
+        if getattr(self.parsl_provider, "nodes_per_block", 1) > 1:
             worker_options.append("--idle-timeout={}".format(20))
         else:
             worker_options.append("--idle-timeout={}".format(int(1e6)))
@@ -180,10 +179,10 @@ def __init__(
         self.timeout = timeout

         default_env_vars = {
-            'OMP_NUM_THREADS': str(self.cores_per_worker),
-            'KMP_AFFINITY': 'granularity=fine,compact,1,0',
-            'OMP_PROC_BIND': 'false',
-            'PYTHONUNBUFFERED': 'TRUE',
+            "OMP_NUM_THREADS": str(self.cores_per_worker),
+            "KMP_AFFINITY": "granularity=fine,compact,1,0",
+            "OMP_PROC_BIND": "false",
+            "PYTHONUNBUFFERED": "TRUE",
         }
         if env_vars is None:
             env_vars = default_env_vars
@@ -254,17 +253,19 @@ def __init__(
             assert max_training_time * 60 < self.max_runtime
         self.max_training_time = max_training_time
         if self.max_workers > 1:
-            message = ('the max_training_time keyword does not work '
-                       'in combination with multi-gpu training. Adjust '
-                       'the maximum number of epochs to control the '
-                       'duration of training')
+            message = (
+                "the max_training_time keyword does not work "
+                "in combination with multi-gpu training. Adjust "
+                "the maximum number of epochs to control the "
+                "duration of training"
+            )
             assert self.max_training_time is None, message

         default_env_vars = {
-            'OMP_NUM_THREADS': str(self.cores_per_worker),
-            'KMP_AFFINITY': 'granularity=fine,compact,1,0',
-            'OMP_PROC_BIND': 'false',
-            'PYTHONUNBUFFERED': 'TRUE',
+            "OMP_NUM_THREADS": str(self.cores_per_worker),
+            "KMP_AFFINITY": "granularity=fine,compact,1,0",
+            "OMP_PROC_BIND": "false",
+            "PYTHONUNBUFFERED": "TRUE",
         }
         if env_vars is None:
             env_vars = default_env_vars
@@ -293,7 +294,9 @@ def wq_resources(self):

         resource_specification["gpus"] = nworkers  # one per GPU
         resource_specification["cores"] = self.cores_available
-        resource_specification["disk"] = 1000 * nworkers  # some random nontrivial amount?
+        resource_specification["disk"] = (
+            1000 * nworkers
+        )  # some random nontrivial amount?
         memory = 1000 * self.cores_available  # similarly rather random
         resource_specification["memory"] = int(memory)
         resource_specification["running_time_min"] = self.max_training_time
@@ -360,8 +363,8 @@ def command(self):

 def parse_size(size):
     size = size.upper()
-    if not re.match(r' ', size):
-        size = re.sub(r'([KMGT]?B)', r' \1', size)
+    if not re.match(r" ", size):
+        size = re.sub(r"([KMGT]?B)", r" \1", size)
     number, unit = [string.strip() for string in size.split()]
     return int(float(number) * units[unit])
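parse_size turns a human-readable size string into a byte count: the regex inserts a space before the unit so that the subsequent split() can separate number and unit. A minimal sketch of how the reformatted helper behaves, using a hypothetical units mapping (psiflow defines its own multipliers, which may differ):

import re

# hypothetical multipliers; psiflow's own `units` mapping may differ
units = {"B": 1, "KB": 1024, "MB": 1024**2, "GB": 1024**3, "TB": 1024**4}

def parse_size(size):
    size = size.upper()
    if not re.match(r" ", size):  # only matches a leading space
        size = re.sub(r"([KMGT]?B)", r" \1", size)  # "512MB" -> "512 MB"
    number, unit = [string.strip() for string in size.split()]
    return int(float(number) * units[unit])

print(parse_size("512MB"))   # 536870912
print(parse_size("1.5 GB"))  # 1610612736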

4 changes: 2 additions & 2 deletions psiflow/functions.py
@@ -218,8 +218,8 @@ def __post_init__(self):
         from mace.tools import torch_tools, utils

         torch_tools.set_default_dtype(self.dtype)
-        if self.device == 'gpu':  # when it's not a specific GPU, use any
-            self.device = 'cuda'
+        if self.device == "gpu":  # when it's not a specific GPU, use any
+            self.device = "cuda"
         self.device = torch_tools.init_device(self.device)

         torch.set_num_threads(self.ncores)

3 changes: 1 addition & 2 deletions psiflow/metrics.py
@@ -172,7 +172,6 @@ def _to_wandb(
     os.environ["WANDB_API_KEY"] = wandb_api_key
     os.environ["WANDB_SILENT"] = "True"
     import tempfile
-    import numpy as np
     from pathlib import Path

     import colorcet as cc

@@ -533,6 +532,6 @@ def to_wandb(self):
         return to_wandb(
             self.wandb_id,
             self.wandb_project,
-            os.environ['WANDB_API_KEY'],
+            os.environ["WANDB_API_KEY"],
             inputs=[self.metrics],
         )

2 changes: 1 addition & 1 deletion psiflow/models/_mace.py
@@ -184,7 +184,7 @@ def train(
     command_cd = "cd $mytmpdir;"
     command_env = ""
     for key, value in env_vars.items():
-        command_env += ' export {}={}; '.format(key, value)
+        command_env += " export {}={}; ".format(key, value)
     command_list = [
         command_tmp,
         command_cd,
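This loop (and the analogous one in psiflow/sampling/sampling.py further down) simply flattens the environment-variable dict into a chain of shell export statements that gets prepended to the training command. A quick sketch with made-up values, not taken from psiflow:

env_vars = {"OMP_NUM_THREADS": "4", "PYTHONUNBUFFERED": "TRUE"}  # example values

command_env = ""
for key, value in env_vars.items():
    command_env += " export {}={}; ".format(key, value)

print(command_env)
# ' export OMP_NUM_THREADS=4;  export PYTHONUNBUFFERED=TRUE; '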
3 changes: 2 additions & 1 deletion psiflow/models/mace_utils.py
@@ -676,9 +676,10 @@ def run(rank: int, args: argparse.Namespace, world_size: int) -> None:


 def main():
+    import signal
+
     import torch
     from mace import tools
-    import signal

     signal.signal(signal.SIGTERM, timeout_handler)
     # main()

4 changes: 2 additions & 2 deletions psiflow/reference/_cp2k.py
@@ -260,8 +260,8 @@ def get_single_atom_references(self, element):
                 "ot": {"minimizer": "CG"}
             }
             # necessary for oxygen calculation, at least in 2024.1
-            key = 'ignore_convergence_failure'
-            cp2k_input_dict['force_eval']['dft']['scf'][key] = "TRUE"
+            key = "ignore_convergence_failure"
+            cp2k_input_dict["force_eval"]["dft"]["scf"][key] = "TRUE"

             reference = CP2K(
                 dict_to_str(cp2k_input_dict),

2 changes: 1 addition & 1 deletion psiflow/sampling/client.py
@@ -77,7 +77,7 @@ def main():
     try:
         t0 = time.time()
         function([template] * 10)  # torch warmp-up before simulation
-        print('time for 10 evaluations: {}'.format(time.time() - t0))
+        print("time for 10 evaluations: {}".format(time.time() - t0))
         run_driver(
             unix=True,
             address=str(Path.cwd() / args.address),

2 changes: 1 addition & 1 deletion psiflow/sampling/optimize.py
@@ -11,7 +11,7 @@

 import psiflow
 from psiflow.data import Dataset
-from psiflow.data.utils import write_frames, read_frames
+from psiflow.data.utils import write_frames
 from psiflow.geometry import Geometry
 from psiflow.hamiltonians import Hamiltonian
 from psiflow.utils.io import save_xml

4 changes: 2 additions & 2 deletions psiflow/sampling/sampling.py
@@ -358,7 +358,7 @@ def _execute_ipi(
     for i, plumed_str in enumerate(plumed_list):
         write_command += 'echo "{}" > metad_input{}.txt; '.format(plumed_str, i)
     for key, value in env_vars.items():
-        write_command += ' export {}={}; '.format(key, value)
+        write_command += " export {}={}; ".format(key, value)
     command_start = command_server + " --nwalkers={}".format(nwalkers)
     command_start += " --input_xml={}".format(inputs[0].filepath)
     command_start += " --start_xyz={}".format(inputs[1].filepath)

@@ -533,7 +533,7 @@ def _sample(
     command_server = definition.server_command()
     command_client = definition.client_command()
     resources = definition.wq_resources(max_nclients)
-    print('ENV VARS')
+    print("ENV VARS")
     print(definition.env_vars)
     result = execute_ipi(
         len(walkers),

13 changes: 8 additions & 5 deletions psiflow/sampling/server.py
@@ -1,7 +1,7 @@
 import argparse
-import os
 import ast
 import glob
+import os
 import signal
 import xml.etree.ElementTree as ET
 from copy import deepcopy

@@ -24,6 +24,7 @@ def remdsort(inputfile, prefix="SRT_"):
     from ipi.inputs.simulation import InputSimulation
     from ipi.utils.io.inputs import io_xml
     from ipi.utils.messages import verbosity
+
     verbosity.level = "low"
     # opens & parses the input file
     ifile = open(inputfile, "r")

@@ -133,8 +134,8 @@ def remdsort(inputfile, prefix="SRT_"):
                     )
                 else:
                     # FIX
-                    if o.format == 'ase':
-                        extension = 'extxyz'
+                    if o.format == "ase":
+                        extension = "extxyz"
                     else:
                         extension = o.format
                     filename = filename + "_" + padb + "." + extension

@@ -343,6 +344,7 @@ def anisotropic_barostat_h0(input_xml, data_start):
 def start(args):
     from ipi.engine.simulation import Simulation
     from ipi.utils.softexit import softexit
+
     data_start = read(args.start_xyz, index=":")
     assert len(data_start) == args.nwalkers
     for i in range(args.nwalkers):

@@ -371,6 +373,7 @@ def start(args):

 def cleanup(args):
     from psiflow.data.utils import _write_frames
+
     with open("input.xml", "r") as f:
         content = f.read()
     if "vibrations" in content:

@@ -458,9 +461,9 @@ def main():
     if not args.cleanup:
         start(args)
     else:
-        #try:
+        # try:
         cleanup(args)
-        #except BaseException as e:  # noqa: B036
+        # except BaseException as e:  # noqa: B036
         #     print(e)
         #     print("i-PI cleanup failed!")
         #     print("files in directory:")

2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -46,7 +46,7 @@ def context(request, tmp_path_factory):
         psiflow_config = yaml.safe_load(f)
     psiflow_config["path"] = tmp_path_factory.mktemp("psiflow_internal")
     psiflow.load(psiflow_config)
-    context = psiflow.context()
+    context = psiflow.context()  # noqa: F841
     yield
     parsl.dfk().cleanup()

2 changes: 1 addition & 1 deletion tests/test_data.py
@@ -181,7 +181,7 @@ def test_readwrite_cycle(dataset, tmp_path):

     geometry = Geometry.from_data(np.ones(2), np.zeros((2, 3)), cell=None)
     geometry.stress = np.array([np.nan] * 9).reshape(3, 3)
-    assert 'nan' not in geometry.to_string()
+    assert "nan" not in geometry.to_string()


 def test_dataset_empty(tmp_path):
