Skip to content

Commit 02f9ae8

Browse files
authored
Load functions via UI (#29)
* add ui feature to load function * make the box look better * updates * updates * single select * only update selection if changed * bugfix * sorted * small fix * update version * update documentation * v0.1.3
1 parent 7dccabb commit 02f9ae8

File tree

9 files changed

+454
-53
lines changed

9 files changed

+454
-53
lines changed

README.md

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,38 @@ Install via ``pip install zndraw`` or ``pip install zndraw[webview]`` to open zn
1212
You can use ZnDraw with the CLI ``zndraw atoms.xyz``.
1313
For a full list of arguments use `zndraw --help`.
1414

15-
To interface with ``zndraw --update-function zndraw.examples.explode`` you need to be able to import via ``from module import function``.
15+
ZnDraw is designed to work with your Python scripts.
16+
To interface you can inherit from `zndraw.examples.UpdateScene` or follow this base class:
1617

17-
The ZnDraw function expects as inputs
18+
```python
19+
import abc
20+
from pydantic import BaseModel
21+
22+
class UpdateScene(BaseModel, abc.ABC):
23+
@abc.abstractmethod
24+
def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]:
25+
pass
26+
```
27+
28+
The ``run`` method expects as inputs
1829
- atom_ids: list[int], the ids of the currently selected atoms
1930
- atoms: ase.Atoms, the configuration as `ase.Atoms` file where atom_ids where selected.
2031

2132
and as an output:
2233
- list[ase.Atoms], a list of ase Atoms objects to display.
2334

35+
36+
You can define the parameters using `pydantic.Field` which will be displayed in the UI.
37+
2438
```python
25-
def function(atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]|Generator[ase.Atoms, None, None]:
26-
...
39+
class MyUpdateCls(UpdateScene):
40+
steps: int = Field(100, le=1000, ge=1)
41+
x: float = Field(0.5, le=5, ge=0)
42+
symbol: str = Field("same")
2743
```
2844

45+
To add your method click on the `+` on the right side of the window.
46+
Your should be able to add your method from the working directory via `module.MyUpdateCls` as long
47+
as it can be imported via `from module import MyUpdateCls`.
48+
2949
![Alt text](https://raw.githubusercontent.com/zincware/ZnDraw/main/misc/zndraw_ui.png "ZnDraw UI")

poetry.lock

Lines changed: 55 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "zndraw"
3-
version = "0.1.2"
3+
version = "0.1.3"
44
description = ""
55
authors = ["zincwarecode <[email protected]>"]
66
license = "Apache-2.0"
@@ -15,6 +15,7 @@ flask = "^2.2.3"
1515
tqdm = "^4.65.0"
1616
pywebview = {version = "^4.0.2", optional = true}
1717
znh5md = "^0.1.6"
18+
pydantic = "^1.10.7"
1819

1920
[tool.poetry.group.dev.dependencies]
2021
black = "^23.3.0"

zndraw/app.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,49 @@ def bonds_step():
5858
@app.route("/select", methods=["POST"])
5959
def select() -> list[int]:
6060
"""Update the selected atoms."""
61-
return (
62-
request.json
63-
) # + [x + 1 for x in request.json] + [x - 1 for x in request.json]
61+
step = request.json["step"]
62+
selected_ids = request.json["selected_ids"]
63+
return {"selected_ids": selected_ids, "updated": False}
64+
65+
# atoms = globals.config.get_atoms(step)
66+
67+
# for id in tuple(selected_ids):
68+
# selected_symbol = atoms[id].symbol
69+
# selected_ids += [
70+
# idx for idx, atom in enumerate(atoms) if atom.symbol == selected_symbol
71+
# ]
72+
73+
# return {"selected_ids": list(set(selected_ids)), "updated": True}
74+
75+
76+
@app.route("/add_update_function", methods=["POST"])
77+
def add_update_function():
78+
"""Add a function to the config."""
79+
globals.config.update_function = request.json
80+
try:
81+
signature = globals.config.get_update_signature()
82+
except (ImportError, ValueError) as err:
83+
return {"error": str(err)}
84+
return signature
85+
86+
87+
@app.route("/update_function_values", methods=["POST"])
88+
def update_function_values():
89+
"""Update the values of the update function."""
90+
globals.config.set_update_function_parameters(request.json)
91+
return {}
92+
93+
94+
@app.route("/select_update_function/<name>")
95+
def select_update_function(name):
96+
"""Select a function from the config."""
97+
globals.config.update_function_name = name
98+
return {}
6499

65100

66101
@app.route("/update", methods=["POST"])
67102
def update_scene():
68-
selected_ids = request.json["selected_ids"]
103+
selected_ids = list(sorted(request.json["selected_ids"]))
69104
step = request.json["step"]
70105

71106
function = globals.config.get_update_function()

zndraw/examples/__init__.py

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,74 @@
1+
import abc
2+
13
import ase
24
import numpy as np
5+
from pydantic import BaseModel, Field
6+
7+
8+
class UpdateScene(BaseModel, abc.ABC):
9+
@abc.abstractmethod
10+
def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]:
11+
pass
12+
13+
14+
class Explode(UpdateScene):
15+
steps: int = Field(100, le=1000, ge=1)
16+
particles: int = Field(10, le=100, ge=1)
17+
18+
def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]:
19+
particles = []
20+
for _atom_id in atom_ids:
21+
for _ in range(self.particles):
22+
particles.append(ase.Atoms("Na", positions=[atoms.positions[_atom_id]]))
23+
24+
for _ in range(self.steps):
25+
struct = atoms.copy()
26+
for particle in particles:
27+
particle.positions += np.random.normal(scale=0.1, size=(1, 3))
28+
struct += particle
29+
yield struct
30+
31+
32+
class Delete(UpdateScene):
33+
def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]:
34+
for idx, atom_id in enumerate(sorted(atom_ids)):
35+
atoms.pop(atom_id - idx) # we remove the atom and shift the index
36+
return [atoms]
37+
38+
39+
class Move(UpdateScene):
40+
x: float = Field(0.5, le=5, ge=0)
41+
y: float = Field(0.5, le=5, ge=0)
42+
z: float = Field(0.5, le=5, ge=0)
43+
44+
def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]:
45+
for atom_id in atom_ids:
46+
atom = atoms[atom_id]
47+
atom.position += np.array([self.x, self.y, self.z])
48+
atoms += atom
49+
return [atoms]
50+
51+
52+
class Duplicate(UpdateScene):
53+
x: float = Field(0.5, le=5, ge=0)
54+
y: float = Field(0.5, le=5, ge=0)
55+
z: float = Field(0.5, le=5, ge=0)
56+
symbol: str = Field("same")
57+
58+
def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]:
59+
for atom_id in atom_ids:
60+
atom = ase.Atom(atoms[atom_id].symbol, atoms[atom_id].position)
61+
atom.position += np.array([self.x, self.y, self.z])
62+
atom.symbol = self.symbol if self.symbol != "same" else atom.symbol
63+
atoms += atom
64+
return [atoms]
365

466

5-
def explode(atom_id: list[int], atoms: ase.Atoms) -> list[ase.Atoms]:
6-
particles = []
7-
for _atom_id in atom_id:
8-
for _ in range(5):
9-
particles.append(ase.Atoms("Na", positions=[atoms.positions[_atom_id]]))
67+
class ChangeType(UpdateScene):
68+
symbol: str = Field("")
1069

11-
for _ in range(102):
12-
struct = atoms.copy()
13-
for particle in particles:
14-
particle.positions += np.random.normal(scale=0.1, size=(1, 3))
15-
struct += particle
16-
yield struct
70+
def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]:
71+
for atom_id in atom_ids:
72+
atoms[atom_id].symbol = self.symbol
73+
print(atoms)
74+
return [atoms]

zndraw/globals.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pathlib
44

55
import ase.io
6+
import pydantic
67
import tqdm
78
import znh5md
89

@@ -20,12 +21,52 @@ class Config:
2021
resolution: int = 5
2122
repeat: tuple = (1, 1, 1)
2223

24+
_update_function_name: str = None
25+
26+
@property
27+
def update_function_name(self):
28+
if self._update_function_name is not None:
29+
return self._update_function_name
30+
return self.update_function.rsplit(".", 1)[1]
31+
32+
@update_function_name.setter
33+
def update_function_name(self, value):
34+
self._update_function_name = value
35+
36+
def get_update_signature(self):
37+
module_name, function_name = self.update_function.rsplit(".", 1)
38+
if function_name in _update_functions:
39+
return _update_functions[self.update_function_name].schema()
40+
module = importlib.import_module(module_name)
41+
instance: pydantic.BaseModel = getattr(module, function_name)()
42+
_update_functions[function_name] = instance
43+
return _update_functions[self.update_function_name].schema()
44+
2345
def get_update_function(self):
46+
module_name, function_name = self.update_function.rsplit(".", 1)
2447
if self.update_function is None:
2548
return None
26-
module_name, function_name = self.update_function.rsplit(".", 1)
49+
if function_name in _update_functions:
50+
return _update_functions[self.update_function_name].run
51+
2752
module = importlib.import_module(module_name)
28-
return getattr(module, function_name)
53+
_update_functions[self.update_function_name] = getattr(module, function_name)()
54+
return _update_functions[self.update_function_name].run
55+
56+
def set_update_function_parameters(self, value):
57+
instance = _update_functions[value["function_id"]]
58+
attribute = value["property"].lower()
59+
value = value["value"]
60+
if instance.__annotations__[attribute] == float:
61+
value = float(value)
62+
elif instance.__annotations__[attribute] == int:
63+
value = int(value)
64+
elif instance.__annotations__[attribute] == bool:
65+
value = bool(value)
66+
else:
67+
value = value
68+
print(f"Setting {attribute} of {instance} to {value}")
69+
setattr(instance, attribute, value)
2970

3071
def load_atoms(self, item=None):
3172
if item == 0:
@@ -58,6 +99,8 @@ def get_atoms(self, step=0) -> ase.Atoms:
5899

59100
# TODO set defaults here and load in typer?
60101

102+
_update_functions = {}
103+
61104
_atoms_cache: dict = {}
62105
config = Config()
63106

zndraw/static/main.css

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ canvas {
3535
display: block;
3636
}
3737

38-
#help_btn {
38+
#right_btn_group {
3939
position: absolute;
4040
top: 5px;
4141
right: 10px;
@@ -45,6 +45,12 @@ canvas {
4545
opacity: 0.5;
4646
}
4747

48+
#add_class {
49+
z-index: 2000;
50+
display: none;
51+
52+
}
53+
4854
.atom-spinner,
4955
.atom-spinner * {
5056
box-sizing: border-box;

0 commit comments

Comments
 (0)