Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expanding for multistates #12

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 279 additions & 0 deletions src/kartograf/atom_mapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# This code is part of kartograf and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/kartograf
from collections import defaultdict
from itertools import product

import copy
import dill
Expand Down Expand Up @@ -880,3 +882,280 @@
molA=A.to_rdkit(), molB=B.to_rdkit()
),
)

def _raw_mapping(self,
                 molA: Chem.Mol,
                 molB: Chem.Mol,
                 max_d: float = 0.95,
                 masked_atoms_molA: Optional[list[int]] = None,
                 masked_atoms_molB: Optional[list[int]] = None,
                 pre_mapped_atoms: Optional[dict[int, int]] = None,
                 map_hydrogens: bool = True,):
    """Compute the distance-based atom mapping between two molecules.

    Atoms of both molecules are masked, a full inter-molecular distance
    matrix is built for the remaining atoms, distances at or beyond
    ``max_d`` are replaced by a large penalty and the configured
    ``self.mapping_algorithm`` solves the assignment. The result is
    translated back to the original atom indices and, when filter
    functions are configured, passed through the additional filter rules.

    Parameters
    ----------
    molA, molB : Chem.Mol
        Molecules to map; each needs a conformer, as the atom positions
        are read from it.
    max_d : float
        Distance threshold; only atom pairs closer than this keep their
        real distance in the assignment problem.
    masked_atoms_molA, masked_atoms_molB : list[int], optional
        Atom indices to exclude from the mapping search.
    pre_mapped_atoms : dict[int, int], optional
        Already known molA -> molB atom pairs. They are added to the
        masks only; this method does not add them to the returned
        mapping.
    map_hydrogens : bool
        Forwarded to ``self._mask_atoms``.

    Returns
    -------
    dict[int, int]
        Mapping of molA atom index to molB atom index.

    Notes
    -----
    Sets ``self.mask_dist_val`` as a side effect.
    """
    # Private copies: the caller's containers must never be mutated.
    excluded_A = copy.deepcopy(
        [] if masked_atoms_molA is None else masked_atoms_molA
    )
    excluded_B = copy.deepcopy(
        [] if masked_atoms_molB is None else masked_atoms_molB
    )
    fixed_pairs = copy.deepcopy(
        dict() if pre_mapped_atoms is None else pre_mapped_atoms
    )

    coords_A = molA.GetConformer().GetPositions()
    coords_B = molB.GetConformer().GetPositions()

    # Pre-mapped atoms must not take part in the assignment.
    if fixed_pairs:
        excluded_A.extend(fixed_pairs.keys())
        excluded_B.extend(fixed_pairs.values())

    index_to_atom_A, coords_A = self._mask_atoms(
        mol=molA,
        mol_pos=coords_A,
        masked_atoms=excluded_A,
        map_hydrogens=map_hydrogens,
    )
    index_to_atom_B, coords_B = self._mask_atoms(
        mol=molB,
        mol_pos=coords_B,
        masked_atoms=excluded_B,
        map_hydrogens=map_hydrogens,
    )

    # Full bipartite distance matrix between the remaining atoms.
    distances = self._get_full_distance_matrix(coords_A, coords_B)

    # np.inf is not accepted by the linear-sum-assignment implementation,
    # therefore a very large finite value marks "too far apart".
    self.mask_dist_val = max_d * 10**6
    penalised_distances = np.array(
        np.ma.where(distances < max_d, distances, self.mask_dist_val)
    )

    masked_index_mapping = self.mapping_algorithm(
        distance_matrix=penalised_distances, max_dist=self.mask_dist_val
    )

    # Undo the masking: translate masked indices to original atom indices.
    mapping = {
        index_to_atom_A[idx_a]: index_to_atom_B[idx_b]
        for idx_a, idx_b in masked_index_mapping.items()
    }

    if self._filter_funcs is None:
        return mapping
    return self._additional_filter_rules(molA, molB, mapping)

Check warning on line 952 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L952

Added line #L952 was not covered by tests


def suggest_multistate_mapping(self, molecules: Iterable[SmallMoleculeComponent],
max_d: float = 0.95,
map_hydrogens: bool = True,
greedy = True
):

#Todo: ensure unique mol names

if(greedy==True):
masks = []
positions = []
for comp in molecules:
mol = comp.to_rdkit()
conf = mol.GetConformer()
pos = conf.GetPositions()
m, mpos = self._mask_atoms(mol, pos,

Check warning on line 970 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L963-L970

Added lines #L963 - L970 were not covered by tests
map_hydrogens=map_hydrogens,
masked_atoms=[])

masks.append(m)
positions.append(mpos)

Check warning on line 975 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L974-L975

Added lines #L974 - L975 were not covered by tests

multi_state_mapping = self._multi_state_greedy_dist_approach(components=molecules, positions=positions, masks=masks)

Check warning on line 977 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L977

Added line #L977 was not covered by tests
else:
# calculate all mappings: - not working atm
mappings = []
for cA in molecules:
for cB in molecules:
if (cA != cB):
mapping = self._raw_mapping(cA.to_rdkit(), cB.to_rdkit(),

Check warning on line 984 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L980-L984

Added lines #L980 - L984 were not covered by tests
max_d=max_d, map_hydrogens=map_hydrogens)
mappings.append(LigandAtomMapping(componentA=cA, componentB=cB,

Check warning on line 986 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L986

Added line #L986 was not covered by tests
componentA_to_componentB=mapping))

#merge mappings:
multi_state_mapping = self._merge_mappings_to_multistate_mapping(mappings=mappings)
multi_state_mapping = list(filter(lambda x: len(x) == len(molecules), multi_state_mapping))
print("raw map", multi_state_mapping)

Check warning on line 992 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L990-L992

Added lines #L990 - L992 were not covered by tests

#Filter Mappings
'''

Check warning on line 995 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L995

Added line #L995 was not covered by tests
Filter all pairs
'''
filtered_raw_mapping = copy.deepcopy(multi_state_mapping)
tmp_filtered_raw_mapping = []
for ligA in molecules:
for ligB in molecules:
if(ligA == ligB):
continue

Check warning on line 1003 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L998-L1003

Added lines #L998 - L1003 were not covered by tests
else:
mappingAB = {m[ligA.name]:m[ligB.name] for m in

Check warning on line 1005 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1005

Added line #L1005 was not covered by tests
multi_state_mapping}
mapping = self._additional_filter_rules(ligA.to_rdkit(),

Check warning on line 1007 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1007

Added line #L1007 was not covered by tests
ligB.to_rdkit(),
mappingAB)

ligA_present = list(mapping.keys())

Check warning on line 1011 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1011

Added line #L1011 was not covered by tests

for m in filtered_raw_mapping:
if(m[ligA.name] in ligA_present):
tmp_filtered_raw_mapping.append(m)
filtered_raw_mapping = tmp_filtered_raw_mapping
tmp_filtered_raw_mapping = []
print("filtered map", multi_state_mapping)

Check warning on line 1018 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1013-L1018

Added lines #L1013 - L1018 were not covered by tests


# Get Core Region
# get all connected sets
connected_sets = {}
for component in molecules:
map_atom_ids = [m[component.name] for m in filtered_raw_mapping]
connected_set = self._get_connected_atom_subsets(component.to_rdkit(), map_atom_ids)
connected_sets[component.name] = connected_set

Check warning on line 1027 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1023-L1027

Added lines #L1023 - L1027 were not covered by tests

# translate mappings into mapping set - tuples
mapping_connected_sets = defaultdict(dict)
for i, m in enumerate(filtered_raw_mapping):
mapping_connected_sets[i] = []
mid = 0
for k, (lig, aid) in enumerate(m.items()):
connected = connected_sets[lig]
for j, s in enumerate(connected):
mid += 1
if (aid in s):
mapping_connected_sets[i].append(mid)
break
print("Connected Set", mapping_connected_sets)

Check warning on line 1041 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1030-L1041

Added lines #L1030 - L1041 were not covered by tests

# max overlap
tup = [tuple(m) for m in mapping_connected_sets.values()]
combinations, counts = np.unique(tup, return_counts=True, axis=0)
max_overlap_tuple = tuple(combinations[list(counts).index(max(counts))])
print(max_overlap_tuple)

Check warning on line 1047 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1044-L1047

Added lines #L1044 - L1047 were not covered by tests

# FIlter for max overlap
filter_map = []
for mid, mapping_sets in mapping_connected_sets.items():
if (max_overlap_tuple == tuple(mapping_sets)):
filter_map.append(filtered_raw_mapping[mid])

Check warning on line 1053 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1050-L1053

Added lines #L1050 - L1053 were not covered by tests

print(filter_map)

Check warning on line 1055 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1055

Added line #L1055 was not covered by tests


return filter_map

Check warning on line 1058 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1058

Added line #L1058 was not covered by tests


def _multi_state_greedy_dist_approach(self, components, positions, masks):
# build ndDistmatrix
euclidean_dist = lambda v: np.sqrt(np.sum(np.square(v), axis=1))

Check warning on line 1063 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1063

Added line #L1063 was not covered by tests

# Calculate and pre-filter long distances
# mol_distance_matrix[molA][atomI][molB_pos]
# = distance between MolA_atomI to molB_atomJ

mol_distance_matrix = []
for molA_atomI_id, molA_pos in enumerate(positions): #MolA
molA_distances = []
for molA_atomI in molA_pos: #Atom of MolA
molA_atomI_distances = []
for molB_id, molB_pos in enumerate(positions): #MolB
if (molA_atomI_id == molB_id):
continue

Check warning on line 1076 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1069-L1076

Added lines #L1069 - L1076 were not covered by tests
else:
molAB_atomI_distances = euclidean_dist(molB_pos - molA_atomI)
molAB_atomI_distances[molAB_atomI_distances > 0.95] = np.inf
molA_atomI_distances.append(molAB_atomI_distances)
molA_distances.append(np.array(molA_atomI_distances))
mol_distance_matrix.append(molA_distances)

Check warning on line 1082 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1078-L1082

Added lines #L1078 - L1082 were not covered by tests

# Calculate raw mappings in N Dimensions, collect tuples and all dists
distance_tuples = defaultdict(list)
for molA_id, molA in enumerate(mol_distance_matrix):
for molA_atomI_id, molA_atomI in enumerate(molA):

Check warning on line 1087 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1085-L1087

Added lines #L1085 - L1087 were not covered by tests
# all inf dist in mols? - no mapping possible
if (any([np.all(np.inf == m_dist) for m_dist in molA_atomI])):
continue

Check warning on line 1090 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1089-L1090

Added lines #L1089 - L1090 were not covered by tests
else:
# Filter only for possible atoms - sparse graph
possible_mappings = []
for molB_distances in molA_atomI:
possible_molB_atom_ids = np.where(molB_distances != np.inf)
possible_mappings.append(np.vstack([possible_molB_atom_ids, molB_distances[possible_molB_atom_ids]]).T)

Check warning on line 1096 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1093-L1096

Added lines #L1093 - L1096 were not covered by tests

# calculate all possible tuples and their sum dist for
# atom molA_atomI.
possible_mappings_id = [list(map(int, a[:, 0])) for a in possible_mappings]
for multi_mapping_atom_ids in product(*possible_mappings_id):
multi_mapping_distance = 0
for k, t in enumerate(multi_mapping_atom_ids):
ti = np.squeeze(np.where(possible_mappings[k][:, 0] == t))
molAB_atomI_distances = np.squeeze(possible_mappings[k][ti, 1])
if isinstance(molAB_atomI_distances, float):
multi_mapping_distance = molAB_atomI_distances

Check warning on line 1107 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1100-L1107

Added lines #L1100 - L1107 were not covered by tests
else:
multi_mapping_distance = np.mean(molAB_atomI_distances)
multi_mapping_atom_ids = list(multi_mapping_atom_ids)
multi_mapping_atom_ids.insert(molA_id, molA_atomI_id)
multi_mapping_atom_ids = tuple(multi_mapping_atom_ids)
distance_tuples[multi_mapping_atom_ids].append(multi_mapping_distance)

Check warning on line 1113 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1109-L1113

Added lines #L1109 - L1113 were not covered by tests

# convolute all distances of mutli atom tuple selection
distance_tuples = {tuple(k): np.sum(v) for k, v in distance_tuples.items()}

Check warning on line 1116 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1116

Added line #L1116 was not covered by tests

# select mapping
already_selected_atoms = []
multistate_atom_mapping = []
for multi_mapping_atom_ids, dist in sorted(distance_tuples.items(), key=lambda x: x[1]):
check_atomIDs = [(i, t) for i, t in enumerate(multi_mapping_atom_ids)]
if (any([ct in already_selected_atoms for ct in check_atomIDs])):
continue

Check warning on line 1124 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1119-L1124

Added lines #L1119 - L1124 were not covered by tests
else:
multistate_atom_mapping.append({components[i].name: masks[i][t] for i, t in check_atomIDs})
already_selected_atoms.extend(check_atomIDs)

Check warning on line 1127 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1126-L1127

Added lines #L1126 - L1127 were not covered by tests

return multistate_atom_mapping

Check warning on line 1129 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1129

Added line #L1129 was not covered by tests

def _merge_mappings_to_multistate_mapping(self, mappings, _only_all_state_mappings: bool = True) -> Iterable[
dict[str, int]]:
# reformat mappings
components = []
found_mappings = []
for m in mappings:
components.extend([m.componentA, m.componentB])
for aa, ab in m.componentA_to_componentB.items():
found_mappings.append({m.componentA.name: aa, m.componentB.name: ab})
components = list(set(components))

Check warning on line 1140 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1134-L1140

Added lines #L1134 - L1140 were not covered by tests

# convolute:
unique_ms_atom_mappings = []
for atom_mapping_tuple in found_mappings:
all_am_related_tuples = list(atom_mapping_tuple.items())
for mapTupB in found_mappings:
if any([k in mapTupB and mapTupB[k] == v for k, v in atom_mapping_tuple.items()]):
all_am_related_tuples.extend(list(mapTupB.items()))

Check warning on line 1148 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1143-L1148

Added lines #L1143 - L1148 were not covered by tests

# unique and sorted:
unique_ms_map = tuple(sorted(set(all_am_related_tuples)))
unique_ms_atom_mappings.append(unique_ms_map)
unique_ms_atom_mappings = list(set(unique_ms_atom_mappings))

Check warning on line 1153 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1151-L1153

Added lines #L1151 - L1153 were not covered by tests

# Filter step: only all state mappings
if (_only_all_state_mappings):
multi_state_mapping = list(filter(lambda x: len(x) == len(components), unique_ms_atom_mappings))

Check warning on line 1157 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1156-L1157

Added lines #L1156 - L1157 were not covered by tests
else:
multi_state_mapping = unique_ms_atom_mappings

Check warning on line 1159 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1159

Added line #L1159 was not covered by tests

return list(map(dict, multi_state_mapping))

Check warning on line 1161 in src/kartograf/atom_mapper.py

View check run for this annotation

Codecov / codecov/patch

src/kartograf/atom_mapper.py#L1161

Added line #L1161 was not covered by tests
60 changes: 60 additions & 0 deletions src/kartograf/utils/multistate_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@


'''
2D
'''

from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from IPython.display import Image, display


def visualize_multistate_mappings_2D(components, multi_state_mapping, ncols=5):
    """Draw all components on one grid and highlight their mapped atoms.

    Parameters
    ----------
    components : iterable
        Molecules to draw; ``to_rdkit()`` and ``name`` are used.
    multi_state_mapping : iterable of dict
        Atom tuples (component name -> atom index), one entry for every
        component in each dict.
    ncols : int
        Number of molecules per grid row.

    Returns
    -------
    IPython.display.Image
        The rendered PNG of the molecule grid.
    """
    # both inputs are traversed several times
    components = list(components)
    multi_state_mapping = list(multi_state_mapping)
    names = [component.name for component in components]

    # grid layout: as many rows as needed for ncols molecules per row
    nrows = -(-len(components) // ncols)  # ceiling division
    d2d = Draw.rdMolDraw2D.MolDraw2DCairo(ncols * 500, nrows * 500, 500, 500)

    # squash to 2D on copies, the input molecules stay untouched
    copies = [Chem.Mol(component.to_rdkit()) for component in components]
    for mol in copies:
        AllChem.Compute2DCoords(mol)

    # align every molecule onto the first one via the mapped atoms
    ref_name, ref_mol = names[0], copies[0]
    for name, mobile_mol in zip(names[1:], copies[1:]):
        atom_map = [(ms_map[name], ms_map[ref_name])
                    for ms_map in multi_state_mapping]
        # without mapped atoms there is nothing to align on
        if atom_map:
            AllChem.AlignMol(mobile_mol, ref_mol, atomMap=atom_map)

    # atoms to highlight, one list per molecule
    atom_lists = [[ms_map[name] for ms_map in multi_state_mapping]
                  for name in names]

    RED = (220 / 255, 50 / 255, 32 / 255, 1.0)
    # standard settings for our visualization
    options = d2d.drawOptions()
    options.useBWAtomPalette()
    options.continuousHighlight = False
    options.setHighlightColour(RED)
    options.addAtomIndices = True
    d2d.DrawMolecules(
        copies,
        highlightAtoms=atom_lists,
    )
    d2d.FinishDrawing()

    return Image(d2d.GetDrawingText())



Loading
Loading