Skip to content

Commit

Permalink
train script tested
Browse files Browse the repository at this point in the history
  • Loading branch information
akabiraka committed May 7, 2024
1 parent c281e15 commit 78e8aa4
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ conda activate .venvs/epbd_bert_condavenv_test1
python setup.py install

conda install -c conda-forge scikit-learn scipy -y
pip uninstall triton # Triton is removed because it depends on the underlying hardware; we did not use it

# To deactivate and remove the venv
conda deactivate
Expand Down
8 changes: 4 additions & 4 deletions epbd_bert/dnabert2_epbd/train_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
tokenizer = get_dnabert2_tokenizer(max_num_tokens=512)
data_collator = SeqLabelEPBDDataCollator(pad_token_id=tokenizer.pad_token_id)
train_dataset = SequenceEPBDDataset(
data_path="data/train_val_test/peaks_with_labels_train.tsv.gz",
data_path="resources/train_val_test/peaks_with_labels_train.tsv.gz",
pydnaepbd_features_path="resources/pydnaepbd_things/coord_flips/id_seqs/", # ../data, resources
tokenizer=tokenizer,
epbd_features_type=configs.epbd_features_type,
)
val_dataset = SequenceEPBDDataset(
data_path="data/train_val_test/peaks_with_labels_val.tsv.gz",
data_path="resources/train_val_test/peaks_with_labels_val.tsv.gz",
pydnaepbd_features_path="resources/pydnaepbd_things/coord_flips/id_seqs/", # ../data, resources
tokenizer=tokenizer,
epbd_features_type=configs.epbd_features_type,
)
print(train_dataset.__len__(), val_dataset.__len__())
train_dl = DataLoader(
Expand Down
6 changes: 4 additions & 2 deletions epbd_bert/dnabert2_epbd_crossattn/train_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
tokenizer = get_dnabert2_tokenizer(max_num_tokens=512)
data_collator = SeqLabelEPBDDataCollator(tokenizer.pad_token_id)
train_dataset = SequenceRandEPBDMultiModalDataset(
data_path="data/train_val_test/peaks_with_labels_train.tsv.gz",
data_path="resources/train_val_test/peaks_with_labels_train.tsv.gz",
pydnaepbd_features_path="resources/pydnaepbd_things/coord_flips/id_seqs/", # ../data, resources
tokenizer=tokenizer,
)
val_dataset = SequenceRandEPBDMultiModalDataset(
data_path="data/train_val_test/peaks_with_labels_val.tsv.gz",
data_path="resources/train_val_test/peaks_with_labels_val.tsv.gz",
pydnaepbd_features_path="resources/pydnaepbd_things/coord_flips/id_seqs/", # ../data, resources
tokenizer=tokenizer,
)
print(train_dataset.__len__(), val_dataset.__len__())
Expand Down

0 comments on commit 78e8aa4

Please sign in to comment.