dnabert2-finetuned train/test scripts updated

akabiraka committed May 7, 2024
1 parent 0953346 commit c281e15
Showing 8 changed files with 35 additions and 23 deletions.
5 changes: 3 additions & 2 deletions .gitignore
@@ -42,5 +42,6 @@ epbd_bert/dnabert2_epbd_crossattn/outputs
 analysis/data
 analysis_motif_discovery/data
 **/__pycache__/
-.eggs/*
-build/*
+*.eggs
+*build
+*dist
14 changes: 14 additions & 0 deletions README.md
@@ -93,6 +93,20 @@ jupyter notebook data_preprocessing/8_create_labels_dict.ipynb
 
 Note: There are some other dataset modules. Each module provides example running instructions at the bottom.
 
+## Training and testing the developed models
+| Model Module | Usage |
+| :--- | :--- |
+| DNABERT2-finetuned | |
+| ```epbd_bert.dnabert2_classifier.train_lightning``` | Train DNABERT2 using train/validation split |
+| ```epbd_bert.dnabert2_classifier.test``` | Test finetuned DNABERT2 on test split |
+| VanillaEPBD-DNABERT2-coordflip | |
+| ```epbd_bert.dnabert2_epbd.train_lightning``` | Train VanillaEPBD-DNABERT2 using train/validation split |
+| ```epbd_bert.dnabert2_epbd.test``` | Test VanillaEPBD-DNABERT2 on test split |
+| EPBDxDNABERT-2 | |
+| ```epbd_bert.dnabert2_epbd_crossattn.train_lightning``` | Train EPBDxDNABERT-2 using train/validation split |
+| ```epbd_bert.dnabert2_epbd_crossattn.test``` | Test EPBDxDNABERT-2 on test split |
+
+Note: Details of each model, along with the ablation studies, can be found in the [Paper](https://www.biorxiv.org/content/10.1101/2024.01.16.575935v2.abstract). To run training or testing, invoke the corresponding module with `python -m`, e.g., ```python -m epbd_bert.dnabert2_classifier.test```.
 
 ## Authors
 
15 changes: 7 additions & 8 deletions epbd_bert/dnabert2_classifier/test.py
@@ -16,11 +16,12 @@
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # ------- first 3 are necessary paths
-model_ckpt_path = "dnabert2/backups/version_1/checkpoints/epoch=54-step=132385.ckpt"
-saved_preds_path = "dnabert2/backups/pred_and_targets_dict_ontest_v1.pkl"
-result_path = "dnabert2/backups/result_v1.tsv"
-test_data_path = "data/train_val_test/peaks_with_labels_test.tsv.gz"
-labels_dict_path = "data/processed/peakfilename_index_dict.pkl"
+# model_ckpt_path = "dnabert2/backups/version_1/checkpoints/epoch=54-step=132385.ckpt"
+model_ckpt_path = "resources/trained_weights/dnabert2_classifier.ckpt"
+saved_preds_path = "outputs/dnabert2_pred_and_targets_dict_ontest.pkl"
+result_path = "outputs/dnabert2_result.tsv"
+test_data_path = "resources/train_val_test/peaks_with_labels_test.tsv.gz"
+labels_dict_path = "resources/processed_data/peakfilename_index_dict.pkl"
 
 tokenizer = get_dnabert2_tokenizer(max_num_tokens=512)
 data_collator = SeqLabelDataCollator(pad_token_id=tokenizer.pad)
@@ -81,9 +82,7 @@ def get_predictions(dl: DataLoader, saved_preds_path: str, compute_again=False):
     pickle_utils.save_as_pickle(preds_and_targets_dict, saved_preds_path)
     return preds_and_targets_dict
 
-preds_and_targets_dict = get_predictions(
-    test_dl, saved_preds_path, compute_again=False
-)
+preds_and_targets_dict = get_predictions(test_dl, saved_preds_path, compute_again=False)
 
 # overall auc-roc
 auc_roc = metrics.roc_auc_score(
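For readers skimming this diff, here is a minimal, self-contained sketch of the compute-or-load pattern that `get_predictions` implements above: run inference once, cache the predictions and targets to a pickle, and reuse them unless recomputation is requested. The model, dataloader, and file path below are generic stand-ins, not the repository's actual API.

```python
# Sketch only: `model` is any nn.Module producing logits, `dl` yields (inputs, targets).
import os
import pickle

import torch
from sklearn import metrics


def get_or_compute_predictions(model, dl, saved_preds_path, compute_again=False):
    """Load cached predictions if available; otherwise run inference and cache them."""
    if os.path.exists(saved_preds_path) and not compute_again:
        with open(saved_preds_path, "rb") as f:
            return pickle.load(f)

    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for x, y in dl:
            preds.append(torch.sigmoid(model(x)))  # multi-label probabilities
            targets.append(y)

    out = {
        "preds": torch.cat(preds).cpu().numpy(),
        "targets": torch.cat(targets).cpu().numpy(),
    }
    with open(saved_preds_path, "wb") as f:
        pickle.dump(out, f)
    return out


# Example scoring, mirroring the "overall auc-roc" step above (micro-averaging is
# one common choice for an overall multi-label score):
# d = get_or_compute_predictions(model, test_dl, "outputs/preds.pkl")
# print(metrics.roc_auc_score(d["targets"], d["preds"], average="micro"))
```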
4 changes: 2 additions & 2 deletions epbd_bert/dnabert2_classifier/train_lightning.py
@@ -16,11 +16,11 @@
 tokenizer = get_dnabert2_tokenizer(max_num_tokens=512)
 data_collator = SeqLabelDataCollator(pad_token_id=tokenizer.pad_token_id)
 train_dataset = SequenceDataset(
-    data_path="data/train_val_test/peaks_with_labels_train.tsv.gz",
+    data_path="resources/train_val_test/peaks_with_labels_train.tsv.gz",
     tokenizer=tokenizer,
 )
 val_dataset = SequenceDataset(
-    data_path="data/train_val_test/peaks_with_labels_val.tsv.gz",
+    data_path="resources/train_val_test/peaks_with_labels_val.tsv.gz",
     tokenizer=tokenizer,
 )
 print(train_dataset.__len__(), val_dataset.__len__())
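As a point of reference, here is a hedged continuation sketch of how the datasets and collator above would typically feed PyTorch `DataLoader`s for Lightning training. It assumes the `train_dataset`, `val_dataset`, and `data_collator` variables from the excerpt; the batch size and worker count are illustrative assumptions, not values from this commit.

```python
# Continuation sketch: assumes train_dataset, val_dataset, and data_collator
# from the train_lightning.py excerpt above; numeric values are assumed.
from torch.utils.data import DataLoader

train_dl = DataLoader(
    train_dataset,             # SequenceDataset over the training split
    batch_size=32,             # assumed value
    shuffle=True,
    collate_fn=data_collator,  # pads variable-length tokenized sequences per batch
    num_workers=4,             # assumed value
)
val_dl = DataLoader(
    val_dataset,               # SequenceDataset over the validation split
    batch_size=32,             # assumed value
    shuffle=False,
    collate_fn=data_collator,
)
# A lightning.Trainer would then consume these via trainer.fit(model, train_dl, val_dl).
```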
4 changes: 2 additions & 2 deletions epbd_bert/path_configs.py
@@ -1,7 +1,7 @@
 # data_home_dir = "epbd-bert/" # "./"
-data_home_dir = "./"
+data_home_dir = ""
 dnabert2_pretrained_dirpath = data_home_dir + "resources/DNABERT-2-117M/"
-print(dnabert2_pretrained_dirpath)
+# print(dnabert2_pretrained_dirpath)
 
 # pydnaepbd_features_path = "data/pydnaepbd_things/coord_flips/id_seqs/"
 # pydnaepbd_features_path = "gen-epbd/cond_epbd/coord_flips/"
13 changes: 5 additions & 8 deletions epbd_bert/utility/data_utils.py
@@ -1,5 +1,6 @@
 import torch
 import pandas as pd
+import numpy as np
 from sklearn.utils.class_weight import compute_class_weight
 import epbd_bert.utility.pickle_utils as pickle_utils
 
@@ -53,12 +54,10 @@ def get_uniform_peaks_metadata(home_dir="/usr/projects/pyDNA_EPBD/tf_dna_binding
     return peaks_metadata_df
 
 
-def compute_multi_class_weights(home_dir="/usr/projects/pyDNA_EPBD/tf_dna_binding/"):
-    data_path = home_dir + "data/train_val_test/peaks_with_labels_train.tsv.gz"
+def compute_multi_class_weights(home_dir=""):
+    data_path = home_dir + "resources/train_val_test/peaks_with_labels_train.tsv.gz"
     data_df = pd.read_csv(data_path, compression="gzip", sep="\t")
-    labels_dict = pickle_utils.load(
-        home_dir + "data/processed/peakfilename_index_dict.pkl"
-    )
+    labels_dict = pickle_utils.load(home_dir + "resources/processed_data/peakfilename_index_dict.pkl")
 
     all_labels = []
 
@@ -68,9 +67,7 @@ def get_all_labels(labels):
         all_labels.append(labels_dict[l])
 
     data_df["labels"].apply(get_all_labels)
-    class_weights = compute_class_weight(
-        "balanced", classes=list(range(len(labels_dict))), y=all_labels
-    )
+    class_weights = compute_class_weight("balanced", classes=np.array(list(range(len(labels_dict)))), y=all_labels)
     class_weights = torch.tensor(class_weights, dtype=torch.float)
 
     # print(class_weights)
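To make the class-weighting change above concrete, here is a small self-contained example of the same scikit-learn call; the label counts are invented for illustration, and only the `compute_class_weight`/`torch.tensor` usage mirrors `compute_multi_class_weights`.

```python
# Toy illustration of balanced class weights; the data below is made up.
import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight

num_classes = 4
# One entry per (sample, label) occurrence, flattened as in compute_multi_class_weights.
all_labels = [0, 0, 0, 1, 1, 2, 3, 3, 3, 3]

# compute_class_weight documents `classes` as an ndarray, which is why this commit
# wraps the range in np.array(...). "balanced" weights are n_samples / (n_classes * count_c).
class_weights = compute_class_weight(
    "balanced",
    classes=np.array(list(range(num_classes))),
    y=all_labels,
)
class_weights = torch.tensor(class_weights, dtype=torch.float)
print(class_weights)  # rarer classes (here, class 2) receive larger weights
```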
Empty file added outputs/.gitkeep
Empty file.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -5,4 +5,5 @@ pandas
 matplotlib
 seaborn
 lightning
-transformers
+transformers
+einops