|
| 1 | +import abc |
| 2 | + |
1 | 3 | import ase
|
2 | 4 | import numpy as np
|
| 5 | +from pydantic import BaseModel, Field |
| 6 | + |
| 7 | + |
| 8 | +class UpdateScene(BaseModel, abc.ABC): |
| 9 | + @abc.abstractmethod |
| 10 | + def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]: |
| 11 | + pass |
| 12 | + |
| 13 | + |
| 14 | +class Explode(UpdateScene): |
| 15 | + steps: int = Field(100, le=1000, ge=1) |
| 16 | + particles: int = Field(10, le=100, ge=1) |
| 17 | + |
| 18 | + def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]: |
| 19 | + particles = [] |
| 20 | + for _atom_id in atom_ids: |
| 21 | + for _ in range(self.particles): |
| 22 | + particles.append(ase.Atoms("Na", positions=[atoms.positions[_atom_id]])) |
| 23 | + |
| 24 | + for _ in range(self.steps): |
| 25 | + struct = atoms.copy() |
| 26 | + for particle in particles: |
| 27 | + particle.positions += np.random.normal(scale=0.1, size=(1, 3)) |
| 28 | + struct += particle |
| 29 | + yield struct |
| 30 | + |
| 31 | + |
| 32 | +class Delete(UpdateScene): |
| 33 | + def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]: |
| 34 | + for idx, atom_id in enumerate(sorted(atom_ids)): |
| 35 | + atoms.pop(atom_id - idx) # we remove the atom and shift the index |
| 36 | + return [atoms] |
| 37 | + |
| 38 | + |
| 39 | +class Move(UpdateScene): |
| 40 | + x: float = Field(0.5, le=5, ge=0) |
| 41 | + y: float = Field(0.5, le=5, ge=0) |
| 42 | + z: float = Field(0.5, le=5, ge=0) |
| 43 | + |
| 44 | + def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]: |
| 45 | + for atom_id in atom_ids: |
| 46 | + atom = atoms[atom_id] |
| 47 | + atom.position += np.array([self.x, self.y, self.z]) |
| 48 | + atoms += atom |
| 49 | + return [atoms] |
| 50 | + |
| 51 | + |
| 52 | +class Duplicate(UpdateScene): |
| 53 | + x: float = Field(0.5, le=5, ge=0) |
| 54 | + y: float = Field(0.5, le=5, ge=0) |
| 55 | + z: float = Field(0.5, le=5, ge=0) |
| 56 | + symbol: str = Field("same") |
| 57 | + |
| 58 | + def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]: |
| 59 | + for atom_id in atom_ids: |
| 60 | + atom = ase.Atom(atoms[atom_id].symbol, atoms[atom_id].position) |
| 61 | + atom.position += np.array([self.x, self.y, self.z]) |
| 62 | + atom.symbol = self.symbol if self.symbol != "same" else atom.symbol |
| 63 | + atoms += atom |
| 64 | + return [atoms] |
3 | 65 |
|
4 | 66 |
|
5 |
| -def explode(atom_id: list[int], atoms: ase.Atoms) -> list[ase.Atoms]: |
6 |
| - particles = [] |
7 |
| - for _atom_id in atom_id: |
8 |
| - for _ in range(5): |
9 |
| - particles.append(ase.Atoms("Na", positions=[atoms.positions[_atom_id]])) |
| 67 | +class ChangeType(UpdateScene): |
| 68 | + symbol: str = Field("") |
10 | 69 |
|
11 |
| - for _ in range(102): |
12 |
| - struct = atoms.copy() |
13 |
| - for particle in particles: |
14 |
| - particle.positions += np.random.normal(scale=0.1, size=(1, 3)) |
15 |
| - struct += particle |
16 |
| - yield struct |
| 70 | + def run(self, atom_ids: list[int], atoms: ase.Atoms) -> list[ase.Atoms]: |
| 71 | + for atom_id in atom_ids: |
| 72 | + atoms[atom_id].symbol = self.symbol |
| 73 | + print(atoms) |
| 74 | + return [atoms] |
0 commit comments