# [patch metadata] diff --git a/Data_encoding.py b/Data_encoding.py new file mode 100644 index 0000000..230b5b3 --- /dev/null +++ b/Data_encoding.py @@ -0,0 +1,562 @@
"""Build the GDSC drug / cell-line datasets used to train the response models.

The module has three stages:

1. resolve drug names to PubChem CIDs and download their SMILES strings;
2. turn SMILES strings into atom-feature graphs and the genetic-feature CSV
   into a binary cell-line x mutation matrix;
3. join both with the IC50 table and write ``TestbedDataset`` files for the
   mixed, drug-blind, cell-blind and single-drug (saliency) experiments.

NOTE(review): this text was reconstructed from a whitespace-collapsed patch, so
block nesting was inferred from the code's meaning; the two places where the
nesting was genuinely ambiguous are marked below.
"""
import os
import csv
from pubchempy import *
import numpy as np
import numbers
import h5py
import math
import pandas as pd
import json,pickle
from collections import OrderedDict
from rdkit import Chem
from rdkit.Chem import MolFromSmiles
import networkx as nx
from Model_utils import *
import random
import pickle
import sys
import matplotlib.pyplot as plt
import argparse


def is_not_float(string_list):
    """Return True if any item of *string_list* cannot be converted by float().

    A non-iterable argument also yields True, as it did before.
    """
    try:
        for string in string_list:
            float(string)
    except (TypeError, ValueError):
        return True
    return False


# The four functions below preprocess the drug data. The drug list is
# downloaded manually and the SMILES strings are fetched with pubchempy.
# Because that is slow, the CIDs and SMILES are cached in CSV files.

folder = "data/"

# Element symbols recognised by atom_features(); anything else maps to the
# trailing 'Unknown' slot.
_ATOM_SYMBOLS = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
                 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb',
                 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', 'H',
                 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr',
                 'Cr', 'Pt', 'Hg', 'Pb', 'Unknown']
# Allowed values for degree / hydrogen count / implicit valence: 0..10.
_COUNT_RANGE = list(range(11))


def load_drug_list():
    """Return the unique drug names from the first column of Druglist.csv."""
    filename = folder + "Druglist.csv"
    # Text mode with newline="" is what the csv module requires on Python 3.
    with open(filename, newline="") as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # skip the header row
        drugs = {line[0] for line in reader if line}
    return list(drugs)


def write_drug_cid():
    """Resolve every drug name to a PubChem CID and cache the result.

    Writes ``pychem_cid.csv`` (one ``name,cid`` row per resolved drug) and
    ``unknow_drug_by_pychem.csv`` (a single row with the unresolved names).
    """
    drugs = load_drug_list()
    unknow_drug = []
    with open(folder + 'pychem_cid.csv', 'w', newline='') as outputfile:
        wr = csv.writer(outputfile)
        for drug in drugs:
            if drug.isdigit():
                # A purely numeric "name" already is a CID: no lookup needed.
                cid = int(drug)
            else:
                c = get_compounds(drug, 'name')
                if not c:
                    unknow_drug.append(drug)
                    continue
                cid = c[0].cid
            print(drug, cid)
            wr.writerow([drug, str(cid)])
    with open(folder + "unknow_drug_by_pychem.csv", 'w', newline='') as outputfile:
        csv.writer(outputfile).writerow(unknow_drug)


def cid_from_other_source():
    """Look up CIDs for the drugs pubchempy could not resolve.

    small_molecule.csv is downloaded from http://lincs.hms.harvard.edu/db/sm/.
    Column 1 is read as the name and column 4 as the CID; the first occurrence
    of a name wins.

    Returns:
        dict mapping drug name -> CID string, restricted to the names listed in
        unknow_drug_by_pychem.csv whose CID parses as a number.
    """
    cid_dict = {}
    with open(folder + "small_molecule.csv", newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for item in reader:
            cid_dict.setdefault(item[1], str(item[4]))

    # Parse with csv rather than readline().split(","): the raw split left the
    # line terminator glued to the last name, so that drug never matched, and
    # it mis-split names that contain a quoted comma.
    with open(folder + "unknow_drug_by_pychem.csv", newline='') as f:
        unknow_drug = set(next(csv.reader(f), []))

    return {k: v for k, v in cid_dict.items()
            if k in unknow_drug and not is_not_float([v])}


def load_cid_dict():
    """Return name -> CID for all drugs (pubchempy results plus manual ones)."""
    with open(folder + "pychem_cid.csv", newline='') as f:
        pychem_dict = {item[0]: item[1] for item in csv.reader(f) if item}
    pychem_dict.update(cid_from_other_source())
    return pychem_dict


def download_smiles():
    """Download SMILES for every known CID and prepend the drug name column.

    The file drug_smiles.csv is downloaded by pubchempy and then rewritten in
    place with an extra leading ``name`` column.
    """
    cids_dict = load_cid_dict()
    cids = list(cids_dict.values())
    inv_cids_dict = {v: k for k, v in cids_dict.items()}
    path = folder + 'drug_smiles.csv'
    download('CSV', path, cids,
             operation='property/CanonicalSMILES,IsomericSMILES', overwrite=True)
    with open(path, newline='') as f:
        reader = csv.reader(f)
        header = ['name'] + next(reader)
        content = [[inv_cids_dict[line[0]]] + line for line in reader]
    with open(path, "w", newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(content)


# The following code converts the SMILES format into one-hot atom features.

def atom_features(atom):
    """Return the one-hot feature vector of an RDKit atom as a numpy array.

    Concatenates the one-hot encodings of the element symbol, degree, total
    hydrogen count and implicit valence, followed by the aromaticity flag.
    """
    return np.array(one_of_k_encoding_unk(atom.GetSymbol(), _ATOM_SYMBOLS) +
                    one_of_k_encoding(atom.GetDegree(), _COUNT_RANGE) +
                    one_of_k_encoding_unk(atom.GetTotalNumHs(), _COUNT_RANGE) +
                    one_of_k_encoding_unk(atom.GetImplicitValence(), _COUNT_RANGE) +
                    [atom.GetIsAromatic()])


def one_of_k_encoding(x, allowable_set):
    """One-hot encode *x*; raise if *x* is not in *allowable_set*."""
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
    return [x == s for s in allowable_set]


def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]


def smile_to_graph(smile):
    """Convert a SMILES string into ``(atom_count, features, edge_index)``.

    ``features`` holds one normalised atom_features() vector per atom and
    ``edge_index`` lists every bond in both directions.

    Raises:
        ValueError: if RDKit cannot parse *smile*.
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        # MolFromSmiles signals failure by returning None; fail with the
        # offending input instead of an AttributeError further down.
        raise ValueError("RDKit could not parse SMILES: {0!r}".format(smile))

    c_size = mol.GetNumAtoms()

    features = []
    for atom in mol.GetAtoms():
        feature = atom_features(atom)
        features.append(feature / sum(feature))

    edges = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]
             for bond in mol.GetBonds()]
    g = nx.Graph(edges).to_directed()
    edge_index = [[e1, e2] for e1, e2 in g.edges]

    return c_size, features, edge_index


def load_drug_smile():
    """Load drug_smiles.csv.

    Returns:
        (drug_dict, drug_smile, smile_graph): name -> index, the SMILES at
        each index (column 2 of the file), and SMILES -> smile_to_graph().
    """
    drug_dict = {}
    drug_smile = []
    with open(folder + "drug_smiles.csv", newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for item in reader:
            name, smile = item[0], item[2]
            # NOTE(review): nesting was ambiguous in the collapsed source. A
            # SMILES is stored only for a new name so that
            # drug_smile[drug_dict[name]] stays aligned even if a name is
            # repeated; for unique names both readings give the same result.
            if name not in drug_dict:
                drug_dict[name] = len(drug_dict)
                drug_smile.append(smile)

    smile_graph = {smile: smile_to_graph(smile) for smile in drug_smile}
    return drug_dict, drug_smile, smile_graph


def save_cell_mut_matrix():
    """Build the binary cell-line x genetic-feature matrix.

    Reads PANCANCER_Genetic_feature.csv (column 1: cell id, column 5: feature
    name, column 6: 0/1 flag) and pickles the feature -> column mapping to the
    file ``mut_dict`` in the working directory.

    Returns:
        (cell_dict, cell_feature): cell id -> row index and the 0/1 matrix.
    """
    cell_dict = {}
    mut_dict = {}
    matrix_list = []
    with open(folder + "PANCANCER_Genetic_feature.csv", newline='') as f:
        reader = csv.reader(f)
        next(reader)
        for item in reader:
            cell_id, mut, is_mutated = item[1], item[5], int(item[6])
            # setdefault assigns the next free index on first sight.
            col = mut_dict.setdefault(mut, len(mut_dict))
            row = cell_dict.setdefault(cell_id, len(cell_dict))
            if is_mutated == 1:
                matrix_list.append((row, col))

    cell_feature = np.zeros((len(cell_dict), len(mut_dict)))
    for row, col in matrix_list:
        cell_feature[row, col] = 1

    with open('mut_dict', 'wb') as fp:
        pickle.dump(mut_dict, fp)

    return cell_dict, cell_feature


# This part extracts the drug - cell interaction strength. The source table
# contains IC50, AUC, Max conc, RMSE and Z_score; only IC50 is used.

def _normalize_ic50(raw):
    """Squash a raw IC50 value (string or number) into the (0, 1) range."""
    return 1 / (1 + pow(math.exp(float(raw)), -0.1))


def _load_ic50_records(only_drug=None):
    """Read PANCANCER_IC.csv into ``(drug, cell, normalised_ic50)`` tuples.

    Every row's IC50 is parsed (so a malformed value fails loudly even when
    the row is filtered out); *only_drug* keeps a single drug's rows.
    """
    records = []
    with open(folder + "PANCANCER_IC.csv", newline='') as f:
        reader = csv.reader(f)
        next(reader)
        for item in reader:
            drug, cell = item[0], item[3]
            ic50 = _normalize_ic50(item[8])
            if only_drug is None or drug == only_drug:
                records.append((drug, cell, ic50))
    return records


def _save_blind_split(group_by_drug, suffix, test_list_file):
    """Write train/val/test datasets split by whole drugs or whole cell lines.

    Groups are taken in first-seen order after shuffling; the 1-based group
    position decides the split (< 80% train, < 90% val, rest test). The group
    key of every test sample is pickled to *test_list_file*.
    """
    records = _load_ic50_records()
    cell_dict, cell_feature = save_cell_mut_matrix()
    drug_dict, drug_smile, smile_graph = load_drug_smile()

    random.shuffle(records)
    groups = {}
    for drug, cell, ic50 in records:
        if drug in drug_dict and cell in cell_dict:
            key = drug if group_by_drug else cell
            groups.setdefault(key, []).append((drug, cell, ic50))

    size = int(len(groups) * 0.8)
    size1 = int(len(groups) * 0.9)
    parts = {name: ([], [], []) for name in ('train', 'val', 'test')}
    held_out = []
    for pos, (key, values) in enumerate(groups.items(), start=1):
        part = 'train' if pos < size else 'val' if pos < size1 else 'test'
        xd, xc, y = parts[part]
        for drug, cell, ic50 in values:
            xd.append(drug_smile[drug_dict[drug]])
            xc.append(cell_feature[cell_dict[cell]])
            y.append(ic50)
            if part == 'test':
                # NOTE(review): nesting was ambiguous in the collapsed source;
                # one entry per test sample keeps this list aligned with the
                # test dataset rows.
                held_out.append(key)

    with open(test_list_file, 'wb') as fp:
        pickle.dump(held_out, fp)

    print(len(parts['train'][2]), len(parts['val'][2]), len(parts['test'][2]))

    dataset = 'GDSC'
    print('preparing ', dataset + '_train.pt in pytorch format!')
    for part in ('train', 'val', 'test'):
        xd, xc, y = (np.asarray(values) for values in parts[part])
        TestbedDataset(root='data', dataset=dataset + '_' + part + suffix,
                       xd=xd, xt=xc, y=y, smile_graph=smile_graph)


def save_mix_drug_cell_matrix():
    """Write the mixed-split datasets ``GDSC_{train,val,test}_mix``.

    Also pickles ``drug_dict`` and stores the drug / cell names of the test
    slice both as ``.npy`` files and as pickles.

    NOTE(review): the test slice (40%-60% of the shuffled samples) lies inside
    the training slice (first 95%). That is kept unchanged because existing
    artefacts depend on it, but it means test samples are seen in training.
    """
    records = _load_ic50_records()
    cell_dict, cell_feature = save_cell_mut_matrix()
    drug_dict, drug_smile, smile_graph = load_drug_smile()

    xd, xc, y, lst_drug, lst_cell = [], [], [], [], []
    random.shuffle(records)
    for drug, cell, ic50 in records:
        if drug in drug_dict and cell in cell_dict:
            xd.append(drug_smile[drug_dict[drug]])
            xc.append(cell_feature[cell_dict[cell]])
            y.append(ic50)
            lst_drug.append(drug)
            lst_cell.append(cell)

    with open('drug_dict', 'wb') as fp:
        pickle.dump(drug_dict, fp)

    xd, xc, y = np.asarray(xd), np.asarray(xc), np.asarray(y)

    total = xd.shape[0]
    size = int(total * 0.95)
    test_slice = slice(int(total * 0.4), int(total * 0.6))

    # np.save appends ".npy", so these do not collide with the pickles below.
    np.save('list_drug_mix_test', lst_drug[test_slice])
    np.save('list_cell_mix_test', lst_cell[test_slice])

    with open('list_drug_mix_test', 'wb') as fp:
        pickle.dump(lst_drug[test_slice], fp)

    with open('list_cell_mix_test', 'wb') as fp:
        pickle.dump(lst_cell[test_slice], fp)

    dataset = 'GDSC'
    print('preparing ', dataset + '_train.pt in pytorch format!')

    TestbedDataset(root='data', dataset=dataset+'_train_mix', xd=xd[:size], xt=xc[:size], y=y[:size], smile_graph=smile_graph)
    TestbedDataset(root='data', dataset=dataset+'_val_mix', xd=xd[size:], xt=xc[size:], y=y[size:], smile_graph=smile_graph)
    TestbedDataset(root='data', dataset=dataset+'_test_mix', xd=xd[test_slice], xt=xc[test_slice], y=y[test_slice], smile_graph=smile_graph)


def save_blind_drug_matrix():
    """Write ``GDSC_{train,val,test}_blind``: test drugs are unseen in training."""
    _save_blind_split(group_by_drug=True, suffix='_blind',
                      test_list_file='drug_bind_test')


def save_blind_cell_matrix():
    """Write ``GDSC_{train,val,test}_cell_blind``: test cell lines are unseen."""
    _save_blind_split(group_by_drug=False, suffix='_cell_blind',
                      test_list_file='cell_bind_test')


def save_best_individual_drug_cell_matrix():
    """Write the Bortezomib-only dataset ``GDSC_bortezomib`` for saliency maps.

    The cell line of every sample is pickled to ``cell_blind_sal``.
    """
    records = _load_ic50_records(only_drug="Bortezomib")
    cell_dict, cell_feature = save_cell_mut_matrix()
    drug_dict, drug_smile, smile_graph = load_drug_smile()

    random.shuffle(records)

    xd_train, xc_train, y_train, cells = [], [], [], []
    for drug, cell, ic50 in records:
        if drug in drug_dict and cell in cell_dict:
            xd_train.append(drug_smile[drug_dict[drug]])
            xc_train.append(cell_feature[cell_dict[cell]])
            y_train.append(ic50)
            cells.append(cell)

    xd_train, xc_train, y_train = np.asarray(xd_train), np.asarray(xc_train), np.asarray(y_train)
    with open('cell_blind_sal', 'wb') as fp:
        pickle.dump(cells, fp)
    dataset = 'GDSC'
    print('preparing ', dataset + '_train.pt in pytorch format!')
    TestbedDataset(root='data', dataset=dataset+'_bortezomib', xd=xd_train, xt=xc_train, y=y_train, smile_graph=smile_graph, saliency_map=True)


def main():
    """Command-line entry point: build the dataset selected by ``--choice``."""
    parser = argparse.ArgumentParser(description='prepare dataset to train model')
    parser.add_argument('--choice', type=int, required=False, default=0, help='0.mix test, 1.saliency value, 2.drug blind, 3.cell blind')
    args = parser.parse_args()
    builders = {
        0: save_mix_drug_cell_matrix,               # mix test dataset
        1: save_best_individual_drug_cell_matrix,   # saliency map dataset
        2: save_blind_drug_matrix,                  # blind drug dataset
        3: save_blind_cell_matrix,                  # blind cell dataset
    }
    builder = builders.get(args.choice)
    if builder is None:
        print("Invalid option, choose 0 -> 3")
        return
    builder()


if __name__ == "__main__":
    main()

# [patch metadata] \ No newline at end of file diff --git a/Model_training.py b/Model_training.py new file mode 100644 index
0000000..27ea304 --- /dev/null +++ b/Model_training.py @@ -0,0 +1,146 @@ +import numpy as np +import pandas as pd +import sys, os +from random import shuffle +import torch +import torch.nn as nn +from models.gat import GATNet +from models.gat_gcn import GAT_GCN +from models.gcn import GCNNet +from models.ginconv import GINConvNet +from models.DbTrs_ge import DbTrs_ge +from models.MTrsDRP import MTrsDRP +from Model_utils import * +import datetime +import argparse + +# training function at each epoch +def train(model, device, train_loader, optimizer, epoch, log_interval): + print('Training on {} samples...'.format(len(train_loader.dataset))) + model.train() + loss_fn = nn.MSELoss() + avg_loss = [] + for batch_idx, data in enumerate(train_loader): + data = data.to(device) + optimizer.zero_grad() + output, _ = model(data) + loss = loss_fn(output, data.y.view(-1, 1).float().to(device)) + loss.backward() + optimizer.step() + avg_loss.append(loss.item()) + if batch_idx % log_interval == 0: + print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, + batch_idx * len(data.x), + len(train_loader.dataset), + 100. 
* batch_idx / len(train_loader), + loss.item())) + return sum(avg_loss)/len(avg_loss) + +def predicting(model, device, loader): + model.eval() + total_preds = torch.Tensor() + total_labels = torch.Tensor() + print('Make prediction for {} samples...'.format(len(loader.dataset))) + with torch.no_grad(): + for data in loader: + data = data.to(device) + output, _ = model(data) + total_preds = torch.cat((total_preds, output.cpu()), 0) + total_labels = torch.cat((total_labels, data.y.view(-1, 1).cpu()), 0) + + labels = np.array(total_labels.cpu().detach().numpy().flatten(), dtype=np.float32) + labels = pd.DataFrame(labels) + labels.to_csv("log/labels_GDRP.csv") + + preds = np.array(total_preds.cpu().detach().numpy().flatten(), dtype=np.float32) + preds = pd.DataFrame(preds) + preds.to_csv("log/preds_GDRP.csv") + return total_labels.numpy().flatten(),total_preds.numpy().flatten() + +def main(modeling, train_batch, val_batch, test_batch, lr, num_epoch, log_interval, cuda_name): + + print('Learning rate: ', lr) + print('Epochs: ', num_epoch) + + model_st = modeling.__name__ + dataset = 'GDSC' + train_losses = [] + val_losses = [] + val_pearsons = [] + print('\nrunning on ', model_st + '_' + dataset ) + processed_data_file_train = 'data/processed/' + dataset + '_train_mix'+'.pt' + processed_data_file_val = 'data/processed/' + dataset + '_val_mix'+'.pt' + processed_data_file_test = 'data/processed/' + dataset + '_test_mix'+'.pt' + if ((not os.path.isfile(processed_data_file_train)) or (not os.path.isfile(processed_data_file_val)) or (not os.path.isfile(processed_data_file_test))): + print('please run create_data.py to prepare data in pytorch format!') + else: + train_data = TestbedDataset(root='data', dataset=dataset+'_train_mix') + val_data = TestbedDataset(root='data', dataset=dataset+'_val_mix') + test_data = TestbedDataset(root='data', dataset=dataset+'_test_mix') + + # make data PyTorch mini-batch processing ready + train_loader = DataLoader(train_data, 
batch_size=train_batch, shuffle=True) + val_loader = DataLoader(val_data, batch_size=val_batch, shuffle=False) + test_loader = DataLoader(test_data, batch_size=test_batch, shuffle=False) + print("CPU/GPU: ", torch.cuda.is_available()) + + # training the model + device = torch.device(cuda_name if torch.cuda.is_available() else "cpu") + print(device) + model = modeling().to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + best_mse = 1000 + best_pearson = 1 + best_epoch = -1 + model_file_name = 'model_' + model_st + '_' + dataset + '.model' + result_file_name = 'result_' + model_st + '_' + dataset + '.csv' + loss_fig_name = 'model_' + model_st + '_' + dataset + '_loss' + pearson_fig_name = 'model_' + model_st + '_' + dataset + '_pearson' + for epoch in range(num_epoch): + train_loss = train(model, device, train_loader, optimizer, epoch+1, log_interval) + G,P = predicting(model, device, val_loader) + ret = [rmse(G,P),mse(G,P),pearson(G,P),spearman(G,P)] + + G_test,P_test = predicting(model, device, test_loader) + ret_test = [rmse(G_test,P_test),mse(G_test,P_test),pearson(G_test,P_test),spearman(G_test,P_test)] + + train_losses.append(train_loss) + val_losses.append(ret[1]) + val_pearsons.append(ret[2]) + + if ret[1] 0: + while j >= 0: + if y[i] > y[j]: + z = z+1 + u = f[i] - f[j] + if u > 0: + S = S + 1 + elif u == 0: + S = S + 0.5 + j = j - 1 + i = i - 1 + j = i-1 + ci = S/z + return ci + +def draw_loss(train_losses, test_losses, title): + plt.figure() + plt.plot(train_losses, label='train loss') + plt.plot(test_losses, label='test loss') + + plt.title(title) + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.legend() + # save image + plt.savefig(title+".png") # should before show method + +def draw_pearson(pearsons, title): + plt.figure() + plt.plot(pearsons, label='test pearson') + + plt.title(title) + plt.xlabel('Epoch') + plt.ylabel('Pearson') + plt.legend() + # save image + plt.savefig(title+".png") # should before show method \ No newline at end 
# [patch metadata] of file diff --git a/Model_validation.py b/Model_validation.py new file mode 100644 index 0000000..1ffd4e3 --- /dev/null +++ b/Model_validation.py @@ -0,0 +1,106 @@
# ===================== file: Model_validation.py =====================
# Per-drug evaluation of a trained model on the mixed test split: mean
# absolute error and Pearson correlation per drug, plotted for the 10 best
# and 10 worst drugs.
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from models.MTrsDRP import MTrsDRP
from Model_utils import *

# NOTE(review): the checkpoint name still carries the model's former class
# name; confirm it matches the file written by the training script.
MODEL_FILE = 'model_BeTrsDRP_mut_GDSC.model'
DATASET = 'GDSC'
TEST_BATCH = 32
NUM_EPOCH = 1


def predicting(model, device, loader):
    """Run *model* over *loader* and return ``(labels, predictions)``.

    Both are flat numpy arrays with one entry per sample, in loader order.
    The previous version never accumulated anything and always returned two
    empty arrays; this mirrors the accumulation done by the training script.
    """
    model.eval()
    total_preds = torch.Tensor()
    total_labels = torch.Tensor()
    print('Make prediction for {} samples...'.format(len(loader.dataset)))
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output, _ = model(data)
            total_preds = torch.cat((total_preds, output.cpu()), 0)
            total_labels = torch.cat((total_labels, data.y.view(-1, 1).cpu()), 0)
    return total_labels.numpy().flatten(), total_preds.numpy().flatten()


def _plot_extremes(scores, ylabel, out_path, k=10):
    """Bar-plot the *k* lowest and *k* highest entries of *scores*.

    Two empty bars separate the groups. A fresh figure is used for every
    call so that successive plots do not pile up in one image.
    """
    ordered = sorted(scores, key=scores.get)
    first, last = ordered[:k], ordered[-k:]
    labels = first + ['', ''] + last
    values = [scores[d] for d in first] + [0, 0] + [scores[d] for d in last]

    plt.figure()
    plt.bar(labels, values)
    plt.xticks(rotation=90)
    plt.ylabel(ylabel)
    plt.title('MUT')
    plt.savefig(out_path, bbox_inches='tight')
    plt.close()


def main():
    """Evaluate the saved model and write Blind_rmse.png / Blind_ccp.png."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = MTrsDRP()
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(MODEL_FILE, map_location=device))
    model.to(device)
    model.eval()

    list_drug_mix_test = np.load('list_drug_mix_test.npy')
    list_cell_mix_test = np.load('list_cell_mix_test.npy')

    test_data = TestbedDataset(root='data', dataset=DATASET + '_test_mix')
    test_loader = DataLoader(test_data, batch_size=TEST_BATCH, shuffle=False)

    test_drug_count = {}
    for drug in list_drug_mix_test:
        test_drug_count[drug] = test_drug_count.get(drug, 0) + 1

    # Sum of per-sample absolute errors for each drug, over all passes.
    test_drug_result = {}
    for _ in range(NUM_EPOCH):
        G_test, P_test = predicting(model, device, test_loader)
        # sqrt of a squared difference is the absolute error per sample.
        abs_err = np.sqrt((G_test - P_test) ** 2)
        for drug, err in zip(list_drug_mix_test, abs_err):
            test_drug_result[drug] = test_drug_result.get(drug, 0) + err

    for key in test_drug_result:
        test_drug_result[key] /= (test_drug_count[key] * NUM_EPOCH)

    # Per-drug Pearson correlation between labels and predictions.
    df = pd.DataFrame({'Drug': list_drug_mix_test,
                       'Cell-line': list_cell_mix_test,
                       'Label': G_test,
                       'Predict': P_test})
    test_drug_pearson_result = {}
    for key, group in df.groupby('Drug'):
        # np.float was removed from NumPy; the builtin float is equivalent.
        test_drug_pearson_result[key] = pearson(
            group['Label'].to_numpy().astype(float),
            group['Predict'].to_numpy().astype(float))

    _plot_extremes(test_drug_result, 'RMSE', "Blind_rmse.png")
    _plot_extremes(test_drug_pearson_result, 'CCp', "Blind_ccp.png")


if __name__ == "__main__":
    main()

# [patch metadata] \ No newline at end of file diff --git a/models/MTrsDRP.py b/models/MTrsDRP.py new file mode 100644 index 0000000..52fa20d --- /dev/null +++ b/models/MTrsDRP.py @@ -0,0 +1,282 @@
# ===================== file: models/MTrsDRP.py =====================
"""
__coding__: utf-8
__Author__: liaoxin
__Time__: 2022/8/21 14:30
__File__: BeTrsDRP_mut.py
__remark__:
__Software__: PyCharm
"""
import torch
import torch.nn as nn
from torch.nn import Linear
from torch.nn import Sequential
from torch.nn import ReLU
import torch.nn.functional as F
from torch_geometric.nn import GINConv as GIN_layer
from torch_geometric.nn import GCNConv as GCN_layer
from torch_geometric.nn import GATConv as GAT_layer, BatchNorm, global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.nn import global_mean_pool as gap
from torch_geometric.nn import global_max_pool as gmp
from torch import nn, einsum
from einops import rearrange

# Distance kernels keyed by name. Kept for compatibility; the attention
# layer below does not use the distance matrix.
DIST_KERNELS = {
    'exp': {
        'fn': lambda t: torch.exp(-t),
        'mask_value_fn': lambda t: torch.finfo(t.dtype).max
    },
    'softmax': {
        'fn': lambda t: torch.softmax(t, dim=-1),
        'mask_value_fn': lambda t: -torch.finfo(t.dtype).max
    }
}


def exists(val):
    """Return True if *val* is not None."""
    return val is not None


def default(val, d):
    """Return *val*, or *d* when *val* is None."""
    return d if not exists(val) else val


class Residual(nn.Module):
    """Wrap *fn* so that its output is added to its input."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return x + self.fn(x, **kwargs)


class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension before calling *fn*."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class FeedForward(nn.Module):
    """Two-layer GELU MLP: dim -> dim * mult -> dim_out (defaults to dim)."""

    def __init__(self, dim, dim_out=None, mult=4):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim_out)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head self-attention with an optional additive adjacency term.

    The output is ``La * softmax(q k^T * scale) + Lg * adjacency`` applied to
    the values. ``Ld``, ``dist_kernel_fn`` and ``distance_mat`` are accepted
    for interface compatibility but do not influence the result.
    """

    def __init__(self, dim=78, heads=6, dim_head=13, Lg=0.5, Ld=0.5, La=1, dist_kernel_fn='exp'):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.La = La
        self.Ld = Ld
        self.Lg = Lg

        self.dist_kernel_fn = dist_kernel_fn

    def forward(self, x, mask=None, adjacency_mat=None, distance_mat=None):
        """Attend over the second-to-last dimension of *x*.

        Args:
            x: input whose last dimension is ``dim``; the rearrange pattern
                below requires it to be 3-D (batch, positions, dim).
            mask: optional boolean (batch, positions) mask; masked pairs get
                the most negative finite score and a zero adjacency weight.
            adjacency_mat: optional (batch, positions, positions) tensor.
            distance_mat: ignored.
        """
        h, La, Lg = self.heads, self.La, self.Lg

        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b n (h qkv d) -> b h n qkv d', h=h, qkv=3).unbind(dim=-2)
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if exists(adjacency_mat):
            adjacency_mat = rearrange(adjacency_mat, 'b i j -> b () i j')

        if exists(mask):
            mask_value = torch.finfo(dots.dtype).max
            mask = mask[:, None, :, None] * mask[:, None, None, :]

            # mask attention
            dots.masked_fill_(~mask, -mask_value)

            if exists(adjacency_mat):
                # Out-of-place: the rearranged tensor is a view of the
                # caller's matrix, which must not be modified.
                adjacency_mat = adjacency_mat.masked_fill(~mask, 0.)

        attn = dots.softmax(dim=-1)

        # sum contributions from attention and adjacency tensors
        attn = attn * La

        if exists(adjacency_mat):
            attn = attn + Lg * adjacency_mat

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class MAT(nn.Module):
    """Stack of pre-norm residual attention / feed-forward layers.

    The input is embedded to ``model_dim``, passed through ``depth`` layers,
    normalised, averaged over the second-to-last dimension and projected to
    ``dim_out``.
    """

    def __init__(
        self,
        *,
        dim_in=78,
        model_dim=78,
        dim_out=78,
        depth=1,
        heads=6,
        Lg=0.5,
        Ld=0.5,
        La=1,
        dist_kernel_fn='exp'
    ):
        super().__init__()

        self.embed_to_model = nn.Linear(dim_in, model_dim)
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            layer = nn.ModuleList([
                Residual(PreNorm(model_dim, Attention(model_dim, heads=heads, Lg=Lg, Ld=Ld, La=La,
                                                      dist_kernel_fn=dist_kernel_fn))),
                Residual(PreNorm(model_dim, FeedForward(model_dim)))
            ])
            self.layers.append(layer)

        self.norm_out = nn.LayerNorm(model_dim)
        self.ff_out = FeedForward(model_dim, dim_out)

    def forward(
        self,
        x,
        mask=None,
        adjacency_mat=None,
        distance_mat=None
    ):
        x = self.embed_to_model(x)

        for (attn, ff) in self.layers:
            x = attn(
                x,
                mask=mask,
                adjacency_mat=adjacency_mat,
                distance_mat=distance_mat
            )
            x = ff(x)

        x = self.norm_out(x)
        x = x.mean(dim=-2)
        x = self.ff_out(x)
        return x


class MTrsDRP(torch.nn.Module):
    """Drug-response model combining a drug-graph branch and a mutation branch.

    Drug branch: MAT -> GCN -> GIN -> batch norm -> max+mean pooling -> MLP.
    Mutation branch: two single-layer transformer encoders -> MLP.
    Both ``connect_dim`` embeddings are concatenated and mapped through
    fully connected layers to a sigmoid output.

    ``ge_features_dim``, ``num_features_xt``, ``embed_dim`` and
    ``meth_feature_dim`` are accepted but unused.
    """

    def __init__(self, output_dim=1, num_features_xd=78,
                 ge_features_dim=1000, num_features_xt=25, embed_dim=128,
                 mut_feature_dim=735, meth_feature_dim=377, connect_dim=128, dropout=0.2):
        # Was super(BeTrsDRP_mut, self): a NameError, since the class is
        # called MTrsDRP.
        super().__init__()

        self.mat = MAT(
            dim_in=num_features_xd,
            model_dim=num_features_xd,
            dim_out=num_features_xd * 2,
            depth=2,
            heads=6,
            Lg=0.5,
            Ld=0.5,
            La=1,
            dist_kernel_fn='exp')
        self.conv_gcn = GCN_layer(num_features_xd * 2, num_features_xd * 2)

        net1 = Sequential(Linear(num_features_xd * 2, num_features_xd * 2), ReLU(), Linear(num_features_xd * 2, num_features_xd * 2))
        self.conv_gin1 = GIN_layer(net1)
        self.bn1 = torch.nn.BatchNorm1d(num_features_xd * 2)

        # Input is the concatenation of max- and mean-pooled node features.
        self.fc1_drug = Linear(num_features_xd * 4, 1500)
        self.fc2_drug = Linear(1500, connect_dim)

        # Activation function and regularization.
        # NOTE(review): the original assigned nn.Dropout(dropout) here and
        # then overwrote it with nn.Dropout(0.5) further down, so the
        # `dropout` argument never had any effect. The effective rate of 0.5
        # is kept to preserve training behaviour; confirm which was intended.
        self.relu = ReLU()
        self.dropout = nn.Dropout(0.5)

        # Single-omics data features -- MUT
        self.EncoderLayer_mut_1 = nn.TransformerEncoderLayer(d_model=mut_feature_dim, nhead=1, dropout=0.5)
        self.conv_mut_1 = nn.TransformerEncoder(self.EncoderLayer_mut_1, 1)
        self.EncoderLayer_mut_2 = nn.TransformerEncoderLayer(d_model=mut_feature_dim, nhead=1, dropout=0.5)
        self.conv_mut_2 = nn.TransformerEncoder(self.EncoderLayer_mut_2, 1)
        self.fc1_mut = Linear(mut_feature_dim, 2944)
        self.fc2_mut = Linear(2944, connect_dim)

        # Fully connected layers
        hidden_dim = 128
        self.fc1_all = Linear(2 * connect_dim, 1024)
        self.fc2_all = Linear(1024, hidden_dim)
        # The output layer consumes fc2_all's output, so it is sized by
        # hidden_dim; the original used connect_dim, which only matched
        # while connect_dim == 128.
        self.out = Linear(hidden_dim, output_dim)

    def forward(self, data):
        """Return ``(prediction, drug_embedding)`` for a batch.

        Reads ``data.x``, ``data.edge_index``, ``data.batch`` and
        ``data.target`` (the mutation features).
        """
        drug_poi_data, drug_edg_index, batch = data.x, data.edge_index, data.batch
        mut_data = data.target

        # Each node is passed through MAT as a length-1 sequence.
        drug_data = torch.unsqueeze(drug_poi_data, 1)
        drug_data = self.mat(drug_data)

        drug_data = self.conv_gcn(drug_data, drug_edg_index)
        drug_data = self.relu(drug_data)

        drug_data = F.relu(self.conv_gin1(drug_data, drug_edg_index))
        drug_data = self.bn1(drug_data)

        drug_data = torch.cat([gmp(drug_data, batch), gap(drug_data, batch)], dim=1)
        drug_data = self.relu(self.fc1_drug(drug_data))
        drug_data = self.dropout(drug_data)
        drug_data = self.fc2_drug(drug_data)

        mut_data = mut_data[:, None, :]
        mut_data = self.conv_mut_1(mut_data)
        mut_data = self.conv_mut_2(mut_data)
        mut_data = mut_data.view(-1, mut_data.shape[1] * mut_data.shape[2])
        mut_data = self.fc1_mut(mut_data)
        mut_data = self.dropout(self.relu(mut_data))
        mut_data = self.fc2_mut(mut_data)
        concat_data = torch.cat((drug_data, mut_data), 1)

        # Hidden layers
        concat_data = self.fc1_all(concat_data)
        concat_data = self.relu(concat_data)
        concat_data = self.dropout(concat_data)
        concat_data = self.fc2_all(concat_data)
        concat_data = self.relu(concat_data)
        concat_data = self.dropout(concat_data)
        out = self.out(concat_data)
        out = torch.sigmoid(out)
        return out, drug_data


# Backward-compatible alias for the model's former class name.
BeTrsDRP_mut = MTrsDRP