Skip to content

Commit

Permalink
Make loadParams part of the PpafmParameters dataclass
Browse files Browse the repository at this point in the history
Also, implement `from_file` class method.
  • Loading branch information
yakutovicha committed Aug 27, 2024
1 parent 775032b commit 04ed3ca
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 92 deletions.
3 changes: 1 addition & 2 deletions ppafm/cli/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,8 @@ def getFzlist(BIGarray, MIN, MAX, points):
raise ValueError('Please provide the file "atomtypes.ini"')


parameters = common.PpafmParameters()
parameters = common.PpafmParameters.from_file("params.yaml")
print(" >> OVEWRITING SETTINGS by params.ini ")
common.loadParams("params.ini", parameters=parameters)
scan_min = parameters.scanMin
scan_max = parameters.scanMax
atoms, nDim, lvec = io.loadGeometry("p_eq.xyz", parameters=parameters)
Expand Down
3 changes: 1 addition & 2 deletions ppafm/cli/fitting/plotLine.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ def selectLine(BIGarray, MIN, MAX, startingPoint, endPoint, nsteps):
if options.points == []:
sys.exit("Error!! The '-p' or '--points' argument is required\npython plotLine.py -p XMINxYMINxZMIN XMAXxYMAXxZMAX")

parameters = common.PpafmParameters()
common.loadParams("params.ini", parameters=parameters)
parameters = common.PpafmParameters.from_file("params.ini")
if os.path.isfile("atomtypes.ini"):
print(">> LOADING LOCAL atomtypes.ini")
FFparams = common.loadSpecies("atomtypes.ini")
Expand Down
4 changes: 1 addition & 3 deletions ppafm/cli/fitting/plotZ.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def find_minimum(array, precision=0.0001):
sys.exit(HELP_MSG)

print(" >> OVEWRITING SETTINGS by params.ini ")
parameters = common.PpafmParameters()

common.loadParams("params.ini", parameters=parameters)
parameters = common.PpafmParameters.from_file("params.ini")
dz = parameters.scanStep[2]
Amp = [parameters.Amplitude]
scan_max = parameters.scanMax[2]
Expand Down
8 changes: 2 additions & 6 deletions ppafm/cli/generateDFTD3.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,10 @@ def main():
)
args = parser.parse_args()

try:
# Try overwriting global parameters with params.ini file.
common.loadParams("params.ini")
except Exception:
print("No params.ini provided => using default parameters.")
parameters = common.PpafmParameters.from_file("params.yaml")

# Overwrite global parameters with command line arguments.
common.apply_options(vars(args))
common.apply_options(vars(args), parameters=parameters)

if args.df_params is not None:
p = args.df_params
Expand Down
3 changes: 1 addition & 2 deletions ppafm/cli/generateElFF.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ def main(argv=None):
args = parser.parse_args(argv)

# Load parameters.
parameters = common.PpafmParameters()
common.loadParams("params.ini", parameters)
parameters = common.PpafmParameters.from_file("params.ini")
common.apply_options(vars(args), parameters)

# Load species.
Expand Down
3 changes: 1 addition & 2 deletions ppafm/cli/generateElFF_point_charges.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ def main(argv=None):
parser = common.CLIParser(description="Generate electrostatic force field by Coulomb interaction of point charges. The generated force field is saved to FFel_{x,y,z}.[ext].")
parser.add_arguments(["input", "input_format", "output_format", "tip", "energy", "noPBC"])
args = parser.parse_args(argv)
parameters = common.PpafmParameters()
common.loadParams("params.ini", parameters=parameters)
parameters = common.PpafmParameters.from_file("params.ini")
common.apply_options(vars(args), parameters=parameters)

computeELFF_pointCharge(args.input, geometry_format=args.input_format, tip=args.tip, save_format=args.output_format, computeVpot=args.energy, parameters=parameters)
Expand Down
3 changes: 1 addition & 2 deletions ppafm/cli/generateLJFF.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ def main(argv=None):
parser = common.CLIParser(description="Generate a Lennard-Jones, Morse, or vdW force field. The generated force field is saved to FFLJ_{x,y,z}.[ext].")
parser.add_arguments(["input", "input_format", "output_format", "ffModel", "energy", "noPBC"])
args = parser.parse_args(argv)
parameters = common.PpafmParameters()
common.loadParams("params.ini", parameters)
parameters = common.PpafmParameters.from_file("params.yaml")
common.apply_options(vars(args), parameters)
species_file = "atomtypes.ini" if Path("atomtypes.ini").is_file() else None
computeLJ(
Expand Down
5 changes: 1 addition & 4 deletions ppafm/cli/generateTraining_PVE.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@

file_format = "xsf"

parameters = common.PpafmParameters()

# Arguments definition.
common.loadParams("params.ini", parameters)
parameters = common.PpafmParameters.from_file("params.yaml")

if os.path.isfile("atomtypes.ini"):
print(">> LOADING LOCAL atomtypes.ini")
Expand Down
3 changes: 1 addition & 2 deletions ppafm/cli/plot_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,11 @@ def main(argv=None):
parser.add_argument( "--bI", action="store_true", help="Plot images for Boltzmann current" )
# fmt: on

parameters = common.PpafmParameters()
parameters = common.PpafmParameters.from_file("params.ini")

args = parser.parse_args(argv)
opt_dict = vars(args)

common.loadParams("params.ini", parameters)
common.apply_options(opt_dict, parameters)

if opt_dict["Laplace"]:
Expand Down
3 changes: 1 addition & 2 deletions ppafm/cli/relaxed_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ def main(argv=None):
parser.add_argument("--pol_s", action="store", type=float, default=1.0, help="Scaling factor for sample polarization")
# fmt: on

parameters = common.PpafmParameters()
parameters = common.PpafmParameters.from_file("params.ini")

args = parser.parse_args(argv)
opt_dict = vars(args)
common.loadParams("params.ini", parameters)
common.apply_options(opt_dict, parameters)

# =============== Setup
Expand Down
4 changes: 1 addition & 3 deletions ppafm/cli/utilities/evalFFLine.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def getLines(lst, atoms):

# ======== main

parameters = common.PpafmParameters()

PPU.loadParams("params.ini")
parameters = common.PpafmParameters.from_file("params.ini")
FFparams = PPU.loadSpecies("atomtypes.ini", parameters=parameters)
elem_dict = PPU.getFFdict(FFparams)
print(elem_dict)
Expand Down
155 changes: 95 additions & 60 deletions ppafm/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import os
import typing
from argparse import ArgumentParser
from pathlib import Path

import numpy as np
import pydantic
import toml
import yaml

from . import cpp_utils

Expand Down Expand Up @@ -75,6 +78,98 @@ class Config:
arbitrary_types_allowed = True
validate_assignment = True

@classmethod
def from_file(cls, file_path: typing.Union[str, Path] = Path("params.ini")):
"""Load parameters from a file."""
file_path = Path(file_path)
object = cls()
suffix = file_path.suffix
try:
with open(file_path) as file:
if suffix == ".ini":
lines = file.readlines()
object.load_ini(lines)
elif suffix == ".yaml":
object.parse_obj(yaml.safe_load(file))
elif suffix == ".toml":
object.parse_obj(toml.load(file))
else:
raise ValueError(f"Unsupported file extension: {suffix}")
except FileNotFoundError:
raise ValueError(f"File {file_path} not found")
return object

def parse_obj(self, obj):
for key, value in obj.items():
setattr(self, key, value)

def dump_parameters(self, file_path: typing.Union[str, Path] = Path("params.ini")):
"""Save parameters to a file."""
file_path = Path(file_path)
suffix = file_path.suffix
try:
with open(file_path, "w") as file:
if suffix == ".yaml":
yaml.dump(self.model_dump(), file, default_flow_style=False, indent=4)
elif suffix == ".toml":
toml.dump(self.model_dump(), file)
else:
raise ValueError(f"Unsupported file extension: {suffix}")
except FileNotFoundError:
raise ValueError(f"File {file_path} not found")

def load_ini(self, lines):
for line in lines:
words = line.split()
if len(words) >= 2:
key = words[0]
if hasattr(self, key):
val = getattr(self, key)
if key[0][0] == "#":
continue
if verbose > 0:
print(key, " is class ", val.__class__)
if isinstance(val, bool):
word = words[1].strip()
setattr(self, key, word[0] == "T" or word[0] == "t")
if verbose > 0:
print(key, getattr(self, key), ">>", word, "<<")
elif isinstance(val, float):
setattr(self, key, float(words[1]))
if verbose > 0:
print(key, getattr(self, key), words[1])
elif isinstance(val, int):
setattr(self, key, int(words[1]))
if verbose > 0:
print(key, getattr(self, key), words[1])
elif isinstance(val, str):
setattr(self, key, words[1])
elif isinstance(val, list):
if isinstance(val[0], float):
setattr(self, key, [float(words[1]), float(words[2]), float(words[3])])
if verbose > 0:
print(key, getattr(self, key), words[1], words[2], words[3])
elif isinstance(val[0], int):
if verbose > 0:
print(key)
setattr(self, key, [int(words[1]), int(words[2]), int(words[3])])
if verbose > 0:
print(key, getattr(self, key), words[1], words[2], words[3])
else:
setattr(self, key, [str(words[1]), float(words[2])])
if verbose > 0:
print(key, getattr(self, key), words[1], words[2])
else:
raise ValueError(f"Parameter {key} is not known")
if self.gridN[0] <= 0:
autoGridN(parameters=self)

self.tip = self.tip.replace('"', "")
self.tip = self.tip.replace("'", "")
# Necessary for working even with quotemarks in params.ini
self.tip_base[0] = self.tip_base[0].replace('"', "")
self.tip_base[0] = self.tip_base[0].replace("'", "")


class CLIParser(ArgumentParser):
"""
Expand Down Expand Up @@ -436,66 +531,6 @@ def autoGridN(parameters):
return parameters.gridN


# overide default parameters by parameters read from a file
def loadParams(fname, parameters):
if verbose > 0:
print(" >> OVERWRITING SETTINGS by " + fname)
fin = open(fname)
for line in fin:
words = line.split()
if len(words) >= 2:
key = words[0]
if hasattr(parameters, key):
val = getattr(parameters, key)
if key[0][0] == "#":
continue
if verbose > 0:
print(key, " is class ", val.__class__)
if isinstance(val, bool):
word = words[1].strip()
setattr(parameters, key, word[0] == "T" or word[0] == "t")
if verbose > 0:
print(key, getattr(parameters, key), ">>", word, "<<")
elif isinstance(val, float):
setattr(parameters, key, float(words[1]))
if verbose > 0:
print(key, getattr(parameters, key), words[1])
elif isinstance(val, int):
setattr(parameters, key, int(words[1]))
if verbose > 0:
print(key, getattr(parameters, key), words[1])
elif isinstance(val, str):
setattr(parameters, key, words[1])
elif isinstance(val, list):
if isinstance(val[0], float):
setattr(parameters, key, [float(words[1]), float(words[2]), float(words[3])])
if verbose > 0:
print(key, getattr(parameters, key), words[1], words[2], words[3])
elif isinstance(val[0], int):
if verbose > 0:
print(key)
setattr(parameters, key, [int(words[1]), int(words[2]), int(words[3])])
if verbose > 0:
print(key, getattr(parameters, key), words[1], words[2], words[3])
else:
setattr(parameters, key, [str(words[1]), float(words[2])])
if verbose > 0:
print(key, getattr(parameters, key), words[1], words[2])
else:
raise ValueError(f"Parameter {key} is not known")
fin.close()
if parameters.gridN[0] <= 0:
autoGridN(parameters=parameters)

parameters.tip = parameters.tip.replace('"', "")
parameters.tip = parameters.tip.replace("'", "")
# Necessary for working even with quotemarks in params.ini
parameters.tip_base[0] = parameters.tip_base[0].replace('"', "")
parameters.tip_base[0] = parameters.tip_base[0].replace("'", "")
if verbose > 0:
print("Loaded parameters from ", fname)


def apply_options(opt, parameters):
if verbose > 0:
print("!!!! OVERRIDE parameters !!!! in Apply options:")
Expand Down
3 changes: 1 addition & 2 deletions ppafm/ocl/AFMulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,8 +657,7 @@ def plot_images(self, X, outdir="afm_images", prefix="df"):

def _get_params(file_path):
"""Get AFMulator arguments from a params.ini file."""
parameters = common.PpafmParameters()
common.loadParams(file_path, parameters=parameters)
parameters = common.PpafmParameters.from_file(file_path)
lvec = np.array(
[
parameters.FFgrid0,
Expand Down

0 comments on commit 04ed3ca

Please sign in to comment.