
Commit a98b498 (parent de84fff)

Reproducing the performance, but not successful yet.

dohlee committed Feb 18, 2023
Showing 3 changed files with 112 additions and 4 deletions.
4 changes: 3 additions & 1 deletion abyssal_pytorch/abyssal_pytorch.py
@@ -14,17 +14,19 @@ def forward(self, x):
         return self.att_conv(x).softmax(dim=-1) * self.feat_conv(x)

 class Abyssal(nn.Module):
-    def __init__(self, use_bn=False):
+    def __init__(self, p_dropout=0.0, use_bn=False):
         super().__init__()

         self.light_attention = LightAttention()
         self.fc_block = nn.Sequential(
             nn.Linear(2560, 2048),
             nn.BatchNorm1d(2048) if use_bn else nn.Identity(),
             nn.ReLU(),  # Not sure if ReLU is the right activation here.
+            nn.Dropout(p_dropout),
             nn.Linear(2048, 1024),
             nn.BatchNorm1d(1024) if use_bn else nn.Identity(),
             nn.ReLU(),  # Not sure if ReLU is the right activation here.
+            nn.Dropout(p_dropout),
             nn.Linear(1024, 1),
         )

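As a quick sanity check of the new dropout/batch-norm head, a minimal sketch: the 2560 -> 2048 -> 1024 -> 1 dimensions come from the diff above, but the input shape expected by LightAttention is not shown here, so this exercises fc_block in isolation and the (batch, 2560) shape is an assumption.

import torch
from abyssal_pytorch import Abyssal  # assumes the package re-exports Abyssal

model = Abyssal(p_dropout=0.1, use_bn=True).eval()  # eval() makes BN/Dropout deterministic
x = torch.randn(4, 2560)        # assumed (batch, features) input to the MLP head
out = model.fc_block(x)         # run only the head changed in this commit
assert out.shape == (4, 1)      # one ddG prediction per example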
72 changes: 69 additions & 3 deletions abyssal_pytorch/train.py
@@ -90,6 +90,44 @@ def validate(model, val_loader, criterion, metrics_f):

     return loss, metrics

+def test(model, test_loader, criterion, metrics_f):
+    model.eval()
+
+    out_fwd, out_rev, label = [], [], []
+    with torch.no_grad():
+        for batch in test_loader:
+            wt_emb, mut_emb = batch['wt_emb'].cuda(), batch['mut_emb'].cuda()
+            _label = batch['label'].cuda().flatten()
+
+            # Score both directions: forward (wt -> mut) and reverse (mut -> wt).
+            _out_fwd = model(wt_emb, mut_emb).flatten()
+            _out_rev = model(mut_emb, wt_emb).flatten()  # Swap wt_emb and mut_emb.
+
+            out_fwd.append(_out_fwd.cpu())
+            out_rev.append(_out_rev.cpu())
+            label.append(_label.cpu())
+
+    out_fwd = torch.cat(out_fwd, dim=0)
+    out_rev = torch.cat(out_rev, dim=0)
+    label = torch.cat(label, dim=0)
+
+    loss = criterion(out_fwd, label).item()
+    metrics = {k: f(out_fwd, label) for k, f in metrics_f.items()}
+
+    # Antisymmetry metrics: a perfectly antisymmetric predictor gives
+    # out_rev == -out_fwd, so pearson_fr should approach -1 and the mean
+    # of the pooled forward/reverse predictions (delta) should approach 0.
+    metrics['pearson_fr'] = pearsonr(out_fwd, out_rev)[0]
+    metrics['delta'] = torch.cat([out_fwd, out_rev], dim=0).mean()
+
+    wandb.log({
+        'test/loss': loss,
+        'test/pearson': metrics['pearson'],
+        'test/spearman': metrics['spearman'],
+        'test/pearson_fr': metrics['pearson_fr'],
+        'test/delta': metrics['delta'],
+    })
+
+    return loss, metrics

 def seed_everything(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
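As a toy illustration of the antisymmetry metrics above (values are hypothetical, not from the repo): a perfectly antisymmetric predictor yields pearson_fr = -1 and delta = 0.

import torch
from scipy.stats import pearsonr

out_fwd = torch.tensor([1.0, -0.5, 2.0])     # made-up forward predictions
out_rev = -out_fwd                           # perfectly antisymmetric reverse predictions
print(pearsonr(out_fwd, out_rev)[0])         # -1.0
print(torch.cat([out_fwd, out_rev]).mean())  # tensor(0.)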
@@ -108,12 +146,14 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--train', required=True)
     parser.add_argument('--val', required=True)
+    parser.add_argument('--test', required=True)
     parser.add_argument('--output', '-o', required=True)
     parser.add_argument('--emb-dir', required=True)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--epochs', type=int, default=72)
     parser.add_argument('--lr', type=float, default=1e-4)
     parser.add_argument('--use-bn', action='store_true', default=False)
+    parser.add_argument('--dropout', type=float, default=0.0)
     parser.add_argument('--seed', type=int, default=42)
     parser.add_argument('--use-wandb', action='store_true', default=False)
     args = parser.parse_args()
@@ -130,13 +170,17 @@ def main():
     val_df = pd.read_csv(args.val)
     val_set = MegaDataset(val_df, emb_dir=args.emb_dir)

+    test_df = pd.read_csv(args.test)
+    test_set = MegaDataset(test_df, emb_dir=args.emb_dir)
+
     train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=16, pin_memory=True)
     val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=16, pin_memory=True)
+    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=16, pin_memory=True)

-    model = Abyssal()
+    model = Abyssal(p_dropout=args.dropout, use_bn=args.use_bn)
     model = model.cuda()

-    optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
     scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.97)
     criterion = nn.MSELoss()
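As an aside, with step_size=1 and gamma=0.97 the StepLR schedule decays the learning rate once per epoch, so over the default 72 epochs it falls from 1e-4 to roughly 1e-4 * 0.97^71, about 1.15e-5. A quick check:

# Sanity check of the decay implied by StepLR(step_size=1, gamma=0.97).
lr = 1e-4
for _ in range(71):  # one decay per completed epoch after the first
    lr *= 0.97
print(f'{lr:.2e}')   # ~1.15e-05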

@@ -146,22 +190,44 @@ def main():
     }

     best_val_loss = np.inf
+    best_val_pearson = -np.inf
+    best_val_spearman = -np.inf
+    best_test_loss = np.inf
+    best_test_pearson = -np.inf
+    best_test_spearman = -np.inf
     for epoch in range(args.epochs):
         train(model, train_loader, optimizer, criterion, metrics_f)
         val_loss, val_metrics = validate(model, val_loader, criterion, metrics_f)

+        test_loss, test_metrics = test(model, test_loader, criterion, metrics_f)
+
         if val_loss < best_val_loss:
             best_val_loss = val_loss
+            best_val_pearson = val_metrics['pearson']
+            best_val_spearman = val_metrics['spearman']

             torch.save(model.state_dict(), args.output)

+            # Test metrics are reported at the epoch with the best validation loss.
+            best_test_loss = test_loss
+            best_test_pearson = test_metrics['pearson']
+            best_test_spearman = test_metrics['spearman']
+
         message = f'Epoch {epoch} Validation: loss {val_loss:.4f}, '
         message += ', '.join([f'{k} {v:.4f}' for k, v in val_metrics.items()])
         print(message)

+        message = f'Epoch {epoch} Test: loss {test_loss:.4f}, '
+        message += ', '.join([f'{k} {v:.4f}' for k, v in test_metrics.items()])
+        print(message)

         scheduler.step()

     wandb.log({
         'best_val_loss': best_val_loss,
+        'best_val_pearson': best_val_pearson,
+        'best_val_spearman': best_val_spearman,
+        'test_loss': best_test_loss,
+        'test_pearson': best_test_pearson,
+        'test_spearman': best_test_spearman,
     })

 if __name__ == '__main__':
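With the new flags, a training run over the mega.* splits might look like this (a sketch: the CSV paths follow the notebook below, while the output name and embedding directory are assumptions):

python abyssal_pytorch/train.py \
    --train data/mega.train.csv \
    --val data/mega.val.csv \
    --test data/mega.test.csv \
    --output abyssal.best.pt \
    --emb-dir data/embeddings \
    --dropout 0.1 \
    --use-bn \
    --use-wandb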
40 changes: 40 additions & 0 deletions note/data-preprocessing.ipynb
@@ -1254,6 +1254,46 @@
"!wc -l ../data/mega.val.csv\n",
"!wc -l ../data/mega.test.csv"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-6.799119306568167e-17"
]
},
"execution_count": 131,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.concat([mega_train, mega_train_sym])['ddG_ML'].mean()"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.6572616968485474"
]
},
"execution_count": 132,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pd.concat([mega_train, mega_train_sym])['ddG_ML'].var()"
]
}
],
"metadata": {
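The near-zero mean and variance of ~1.657 reported above are what symmetric augmentation predicts: if mega_train_sym holds the reverse mutations with negated ddG_ML labels (an assumption about the preprocessing, which is not shown in this hunk), the pooled labels cancel pairwise, so the mean vanishes up to float error while the spread is preserved. A minimal sketch:

import pandas as pd

# Toy stand-in for the training frame; the values are hypothetical.
mega_train = pd.DataFrame({'ddG_ML': [1.2, -0.3, 2.5]})

# Reverse mutations flip the sign of the stability change.
mega_train_sym = mega_train.copy()
mega_train_sym['ddG_ML'] = -mega_train_sym['ddG_ML']

combined = pd.concat([mega_train, mega_train_sym])
print(combined['ddG_ML'].mean())  # 0.0 by construction
print(combined['ddG_ML'].var())   # the spread of the labels is preserved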
