Skip to content

Commit f50fb61

Browse files
authored
Speed up json loader (#1163)
* Speed up json loader
* Include port position when writing; add a test that loads back a compound with ports
* Change name and behavior for include_ports/show_ports
* Switch the default back
* Fix unit test (update keyword)
1 parent 9d20e3d commit f50fb61

File tree

4 files changed

+58
-33
lines changed

4 files changed

+58
-33
lines changed

mbuild/compound.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -1953,7 +1953,7 @@ def _visualize_py3dmol(
19531953
tmp_dir = tempfile.mkdtemp()
19541954
cloned.save(
19551955
os.path.join(tmp_dir, "tmp.mol2"),
1956-
show_ports=show_ports,
1956+
include_ports=show_ports,
19571957
overwrite=True,
19581958
parmed_kwargs={"infer_residues": False},
19591959
)
@@ -1984,7 +1984,7 @@ def _visualize_nglview(
19841984
19851985
Parameters
19861986
----------
1987-
show_ports : bool, optional, default=False
1987+
show_ports : bool, optional, default=False
19881988
Visualize Ports in addition to Particles
19891989
"""
19901990
nglview = import_("nglview")
@@ -2001,7 +2001,7 @@ def remove_digits(x):
20012001
tmp_dir = tempfile.mkdtemp()
20022002
self.save(
20032003
os.path.join(tmp_dir, "tmp.mol2"),
2004-
show_ports=show_ports,
2004+
include_ports=show_ports,
20052005
overwrite=True,
20062006
)
20072007
widget = nglview.show_file(os.path.join(tmp_dir, "tmp.mol2"))
@@ -2930,7 +2930,7 @@ def _energy_minimize_openbabel(
29302930
def save(
29312931
self,
29322932
filename,
2933-
show_ports=False,
2933+
include_ports=False,
29342934
forcefield_name=None,
29352935
forcefield_files=None,
29362936
forcefield_debug=False,
@@ -2952,7 +2952,7 @@ def save(
29522952
'hoomdxml', 'gsd', 'gro', 'top', 'lammps', 'lmp', 'mcf', 'pdb', 'xyz',
29532953
'json', 'mol2', 'sdf', 'psf'. See parmed/structure.py for more
29542954
information on savers.
2955-
show_ports : bool, optional, default=False
2955+
include_ports : bool, optional, default=False
29562956
Save ports contained within the compound.
29572957
forcefield_files : str, optional, default=None
29582958
Apply a forcefield to the output file using a forcefield provided
@@ -3024,7 +3024,7 @@ def save(
30243024
When saving the compound as a json, only the following arguments are
30253025
used:
30263026
* filename
3027-
* show_ports
3027+
* include_ports
30283028
30293029
See Also
30303030
--------
@@ -3039,7 +3039,7 @@ def save(
30393039
conversion.save(
30403040
self,
30413041
filename,
3042-
show_ports,
3042+
include_ports,
30433043
forcefield_name,
30443044
forcefield_files,
30453045
forcefield_debug,
@@ -3232,13 +3232,13 @@ def from_trajectory(
32323232
)
32333233

32343234
def to_trajectory(
3235-
self, show_ports=False, chains=None, residues=None, box=None
3235+
self, include_ports=False, chains=None, residues=None, box=None
32363236
):
32373237
"""Convert to an md.Trajectory and flatten the compound.
32383238
32393239
Parameters
32403240
----------
3241-
show_ports : bool, optional, default=False
3241+
include_ports : bool, optional, default=False
32423242
Include all port atoms when converting to trajectory.
32433243
chains : mb.Compound or list of mb.Compound
32443244
Chain types to add to the topology
@@ -3261,7 +3261,7 @@ def to_trajectory(
32613261
"""
32623262
return conversion.to_trajectory(
32633263
compound=self,
3264-
show_ports=show_ports,
3264+
include_ports=include_ports,
32653265
chains=chains,
32663266
residues=residues,
32673267
box=box,
@@ -3323,7 +3323,7 @@ def to_parmed(
33233323
box=None,
33243324
title="",
33253325
residues=None,
3326-
show_ports=False,
3326+
include_ports=False,
33273327
infer_residues=False,
33283328
infer_residues_kwargs={},
33293329
):
@@ -3341,7 +3341,7 @@ def to_parmed(
33413341
residues : str or list of str, optional, default=None
33423342
Labels of residues in the Compound. Residues are assigned by checking
33433343
against Compound.name.
3344-
show_ports : boolean, optional, default=False
3344+
include_ports : boolean, optional, default=False
33453345
Include all port atoms when converting to a `Structure`.
33463346
infer_residues : bool, optional, default=False
33473347
Attempt to assign residues based on the number of bonds and particles in
@@ -3364,7 +3364,7 @@ def to_parmed(
33643364
box=box,
33653365
title=title,
33663366
residues=residues,
3367-
show_ports=show_ports,
3367+
include_ports=include_ports,
33683368
infer_residues=infer_residues,
33693369
infer_residues_kwargs=infer_residues_kwargs,
33703370
)
@@ -3400,7 +3400,7 @@ def to_pybel(
34003400
box=None,
34013401
title="",
34023402
residues=None,
3403-
show_ports=False,
3403+
include_ports=False,
34043404
infer_residues=False,
34053405
):
34063406
"""Create a pybel.Molecule from a Compound.
@@ -3413,7 +3413,7 @@ def to_pybel(
34133413
residues : str or list of str
34143414
Labels of residues in the Compound. Residues are assigned by
34153415
checking against Compound.name.
3416-
show_ports : boolean, optional, default=False
3416+
include_ports : boolean, optional, default=False
34173417
Include all port atoms when converting to a `pybel.Molecule`.
34183418
infer_residues : bool, optional, default=False
34193419
Attempt to assign residues based on names of children
@@ -3438,7 +3438,7 @@ def to_pybel(
34383438
box=box,
34393439
title=title,
34403440
residues=residues,
3441-
show_ports=show_ports,
3441+
include_ports=include_ports,
34423442
)
34433443

34443444
def to_smiles(self, backend="pybel"):

mbuild/conversion.py

+16-14
Original file line numberDiff line numberDiff line change
@@ -966,7 +966,7 @@ def from_gmso(
966966
def save(
967967
compound,
968968
filename,
969-
show_ports=False,
969+
include_ports=False,
970970
forcefield_name=None,
971971
forcefield_files=None,
972972
forcefield_debug=False,
@@ -990,7 +990,7 @@ def save(
990990
'hoomdxml', 'gsd', 'gro', 'top', 'lammps', 'lmp', 'mcf', 'xyz', 'pdb',
991991
'sdf', 'mol2', 'psf'. See parmed/structure.py for more information on
992992
savers.
993-
show_ports : bool, optional, default=False
993+
include_ports : bool, optional, default=False
994994
Save ports contained within the compound.
995995
forcefield_files : str, optional, default=None
996996
Apply a forcefield to the output file using a forcefield provided by the
@@ -1054,7 +1054,7 @@ def save(
10541054
-----
10551055
When saving the compound as a json, only the following arguments are used:
10561056
- filename
1057-
- show_ports
1057+
- include_ports
10581058
10591059
See Also
10601060
--------
@@ -1068,7 +1068,9 @@ def save(
10681068
extension = os.path.splitext(filename)[-1]
10691069

10701070
if extension == ".json":
1071-
compound_to_json(compound, file_path=filename, include_ports=show_ports)
1071+
compound_to_json(
1072+
compound, file_path=filename, include_ports=include_ports
1073+
)
10721074
return
10731075

10741076
# Savers supported by mbuild.formats
@@ -1098,7 +1100,7 @@ def save(
10981100
structure = compound.to_parmed(
10991101
box=box,
11001102
residues=residues,
1101-
show_ports=show_ports,
1103+
include_ports=include_ports,
11021104
**parmed_kwargs,
11031105
)
11041106
# Apply a force field with foyer if specified
@@ -1301,7 +1303,7 @@ def to_parmed(
13011303
box=None,
13021304
title="",
13031305
residues=None,
1304-
show_ports=False,
1306+
include_ports=False,
13051307
infer_residues=False,
13061308
infer_residues_kwargs={},
13071309
):
@@ -1321,7 +1323,7 @@ def to_parmed(
13211323
residues : str or list of str, optional, default=None
13221324
Labels of residues in the Compound. Residues are assigned by checking
13231325
against Compound.name.
1324-
show_ports : boolean, optional, default=False
1326+
include_ports : boolean, optional, default=False
13251327
Include all port atoms when converting to a `Structure`.
13261328
infer_residues : bool, optional, default=False
13271329
Attempt to assign residues based on the number of bonds and particles in
@@ -1361,7 +1363,7 @@ def to_parmed(
13611363
atom_residue_map = dict()
13621364

13631365
# Loop through particles and initialize ParmEd atoms
1364-
for atom in compound.particles(include_ports=show_ports):
1366+
for atom in compound.particles(include_ports=include_ports):
13651367
if atom.port_particle:
13661368
current_residue = port_residue
13671369
atom_residue_map[atom] = current_residue
@@ -1458,13 +1460,13 @@ def to_parmed(
14581460

14591461

14601462
def to_trajectory(
1461-
compound, show_ports=False, chains=None, residues=None, box=None
1463+
compound, include_ports=False, chains=None, residues=None, box=None
14621464
):
14631465
"""Convert to an md.Trajectory and flatten the compound.
14641466
14651467
Parameters
14661468
----------
1467-
show_ports : bool, optional, default=False
1469+
include_ports : bool, optional, default=False
14681470
Include all port atoms when converting to trajectory.
14691471
chains : mb.Compound or list of mb.Compound
14701472
Chain types to add to the topology
@@ -1485,7 +1487,7 @@ def to_trajectory(
14851487
_to_topology
14861488
"""
14871489
md = import_("mdtraj")
1488-
atom_list = [particle for particle in compound.particles(show_ports)]
1490+
atom_list = [particle for particle in compound.particles(include_ports)]
14891491

14901492
top = _to_topology(compound, atom_list, chains, residues)
14911493

@@ -1650,7 +1652,7 @@ def to_pybel(
16501652
box=None,
16511653
title="",
16521654
residues=None,
1653-
show_ports=False,
1655+
include_ports=False,
16541656
infer_residues=False,
16551657
):
16561658
"""Create a pybel.Molecule from a Compound.
@@ -1665,7 +1667,7 @@ def to_pybel(
16651667
residues : str or list of str
16661668
Labels of residues in the Compound. Residues are assigned by checking
16671669
against Compound.name.
1668-
show_ports : boolean, optional, default=False
1670+
include_ports : boolean, optional, default=False
16691671
Include all port atoms when converting to a `pybel.Molecule`.
16701672
infer_residues : bool, optional, default=False
16711673
Attempt to assign residues based on names of children
@@ -1697,7 +1699,7 @@ def to_pybel(
16971699
compound_residue_map = dict()
16981700
atom_residue_map = dict()
16991701

1700-
for i, part in enumerate(compound.particles(include_ports=show_ports)):
1702+
for i, part in enumerate(compound.particles(include_ports=include_ports)):
17011703
if residues and part.name in residues:
17021704
current_residue = mol.NewResidue()
17031705
current_residue.SetName(part.name)

mbuild/formats/json_formats.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import ele
77

88
import mbuild as mb
9+
from mbuild.bond_graph import BondGraph
910
from mbuild.exceptions import MBuildError
1011

1112

@@ -56,6 +57,7 @@ def compound_from_json(json_file):
5657
sub_cmpd = _dict_to_mb(sub_compound)
5758
converted_dict[sub_compound["id"]] = sub_cmpd
5859
sub_cmpd = converted_dict[sub_compound["id"]]
60+
sub_cmpd.bond_graph = None
5961

6062
label_str = sub_compound["label"]
6163
label_list = compound.get("label_list", {})
@@ -64,7 +66,12 @@ def compound_from_json(json_file):
6466
parent_compound.labels[key] = list()
6567
if sub_compound["id"] in vals:
6668
parent_compound.labels[key].append(sub_cmpd)
67-
parent_compound.add(sub_cmpd, label=label_str)
69+
parent_compound.add(sub_cmpd, check_box_size=False, label=label_str)
70+
71+
parent.bond_graph = BondGraph()
72+
parent.bond_graph.add_nodes_from(
73+
[particle for particle in parent.particles()]
74+
)
6875

6976
_add_ports(compound_dict, converted_dict)
7077
_add_bonds(compound_dict, parent, converted_dict)
@@ -152,6 +159,7 @@ def _particle_info(cmpd, include_ports=False):
152159
else:
153160
port_info["anchor"] = None
154161
port_info["label"] = None
162+
port_info["pos"] = port.pos.tolist()
155163
# Is this the most efficient way?
156164
for key, val in cmpd.labels.items():
157165
if (val == port) and val.port_particle:
@@ -236,15 +244,21 @@ def _add_ports(compound_dict, converted_dict):
236244
for port in ports:
237245
label_str = port["label"]
238246
port_to_add = mb.Port(anchor=converted_dict[port["anchor"]])
239-
converted_dict[compound["id"]].add(port_to_add, label_str)
247+
if port.get("pos", None) is not None:
248+
port_to_add.translate_to(port.get("pos"))
249+
converted_dict[compound["id"]].add(
250+
port_to_add, label_str, check_box_size=False
251+
)
240252
# Not necessary to add same port twice
241253
compound["ports"] = None
242254
ports = subcompound.get("ports", None)
243255
if ports:
244256
for port in ports:
245257
label_str = port["label"]
246258
port_to_add = mb.Port(anchor=converted_dict[port["anchor"]])
247-
converted_dict[subcompound["id"]].add(port_to_add, label_str)
259+
converted_dict[subcompound["id"]].add(
260+
port_to_add, label_str, check_box_size=False
261+
)
248262
subcompound["ports"] = None
249263

250264

mbuild/tests/test_json_formats.py

+9
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,12 @@ def test_float_64_position(self):
121121
compound_to_json(ethane, "ethane.json", include_ports=True)
122122
ethane_copy = compound_from_json("ethane.json")
123123
assert np.allclose(ethane.xyz, ethane_copy.xyz, atol=10**-6)
124+
125+
def test_compound_with_port(self):
126+
ch2 = mb.lib.moieties.CH2()
127+
ch2.save("ch2.json", include_ports=True, overwrite=True)
128+
129+
loaded_ch2 = mb.load("ch2.json")
130+
assert len(loaded_ch2.all_ports()) == 2
131+
for port in loaded_ch2.all_ports():
132+
assert port.separation == 0.07

0 commit comments

Comments
 (0)