-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathsmiles2graph.py
113 lines (92 loc) · 3.79 KB
/
smiles2graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from rdkit import Chem
import numpy as np
import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Data
from dgllife.utils import *
def atom_to_feature_vector(atom):
    """
    Featurize a single RDKit atom with dgllife's atom featurizers.

    :param atom: rdkit atom object (``rdkit.Chem.rdchem.Atom``)
    :return: list -- the outputs of the featurizers below, concatenated in
        this order: atom type, degree, implicit valence, formal charge,
        number of radical electrons, hybridization, aromaticity, total
        number of hydrogens, ring membership, chirality type.
    """
    atom_featurizers = [
        atom_type_one_hot,
        atom_degree_one_hot,
        atom_implicit_valence_one_hot,
        atom_formal_charge,
        atom_num_radical_electrons,
        atom_hybridization_one_hot,
        atom_is_aromatic,
        atom_total_num_H_one_hot,
        atom_is_in_ring,
        atom_chirality_type_one_hot,
    ]
    return ConcatFeaturizer(atom_featurizers)(atom)
def bond_to_feature_vector(bond):
    """
    Featurize a single RDKit bond with dgllife's bond featurizers.

    :param bond: rdkit bond object (``rdkit.Chem.rdchem.Bond``)
    :return: list -- currently only the bond-type one-hot encoding.
    """
    # Only the bond type is encoded. dgllife's bond_is_conjugated,
    # bond_is_in_ring and bond_stereo_one_hot could be appended here but are
    # deliberately left disabled.
    bond_featurizers = [bond_type_one_hot]
    return ConcatFeaturizer(bond_featurizers)(bond)
def smiles2graph(mol):
    """
    Convert a SMILES string or an RDKit Mol into a PyG ``Data`` graph.

    No salt stripping or other standardization is applied: every fragment
    of the input becomes part of the graph.

    :param mol: SMILES string (str) or ``rdkit.Chem.rdchem.Mol``
    :return: ``torch_geometric.data.Data`` with
        ``x`` -- float tensor, one row of atom features per atom;
        ``edge_index`` -- long tensor of shape [2, num_edges], with every
        bond stored twice (i -> j and j -> i);
        ``edge_attr`` -- float tensor of shape [num_edges, num_bond_features].
    :raises ValueError: if ``mol`` is a string that RDKit cannot parse.
    """
    if not isinstance(mol, Chem.rdchem.Mol):
        smiles = mol
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles reports parse failures by returning None; fail
            # here with the offending input rather than an AttributeError
            # on mol.GetAtoms() below.
            raise ValueError(f'Could not parse SMILES: {smiles!r}')

    # atoms
    atom_features_list = [atom_to_feature_vector(atom) for atom in mol.GetAtoms()]
    x = np.array(atom_features_list, dtype=np.int64)

    # bonds
    if mol.GetNumBonds() > 0:
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = bond_to_feature_vector(bond)
            # add the edge in both directions, sharing the same features
            edges_list += [(i, j), (j, i)]
            edge_features_list += [edge_feature, edge_feature]
        # edge_index: graph connectivity in COO format, shape [2, num_edges]
        edge_index = np.array(edges_list, dtype=np.int64).T
        # edge_attr: edge feature matrix, shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype=np.int64)
    else:
        # The empty edge_attr must have the same width that
        # bond_to_feature_vector produces for real bonds, otherwise bond-less
        # molecules cannot be batched with bonded ones. Measure it on a
        # throwaway C-C bond instead of hard-coding it.
        probe_bond = Chem.MolFromSmiles('CC').GetBondWithIdx(0)
        num_bond_features = len(bond_to_feature_vector(probe_bond))
        edge_index = np.empty((2, 0), dtype=np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype=np.int64)

    # Node and edge features are both cast to float; connectivity stays long.
    graph = Data(x=torch.tensor(x, dtype=torch.float),
                 edge_index=torch.tensor(edge_index, dtype=torch.long),
                 edge_attr=torch.tensor(edge_attr, dtype=torch.float))
    return graph
def save_drug_graph(csv_path='./data/IC50_GDSC/drug_smiles.csv',
                    out_path='./data/feature/drug_feature_graph.npy'):
    """
    Featurize every drug in a SMILES table and save the graphs to disk.

    The CSV's first column is used as the dictionary key (presumably a drug
    identifier; verify against the data file) and its third column as the
    SMILES string. The resulting dict is written with ``np.save``, which
    pickles it. Read it back with
    ``np.load(out_path, allow_pickle=True).item()``.

    :param csv_path: CSV file of drugs; defaults to the GDSC IC50 drug list.
    :param out_path: destination ``.npy`` file; its parent directory is
        created if it does not exist.
    :return: dict mapping key -> torch_geometric ``Data`` graph.
    """
    import os

    smiles = pd.read_csv(csv_path)
    # Positional columns: 0 -> key, 2 -> SMILES. A later duplicate key
    # overwrites an earlier one.
    drug_dict = {key: smiles2graph(smi)
                 for key, smi in zip(smiles.iloc[:, 0], smiles.iloc[:, 2])}

    out_dir = os.path.dirname(out_path)
    if out_dir:
        # np.save does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
    np.save(out_path, drug_dict)
    return drug_dict
if __name__ == '__main__':
    # Smoke test: featurize one example molecule and report tensor shapes.
    example_smiles = 'O1C=C[C@H]([C@H]1O2)c3c2cc(OC)c4c3OC(=O)C5=C4CCC(=O)5'
    example_graph = smiles2graph(example_smiles)
    for feature_tensor in (example_graph.x, example_graph.edge_attr):
        print(feature_tensor.shape)
    # Uncomment to featurize the whole drug table and write it to disk.
    # save_drug_graph()