Skip to content

Commit

Permalink
Merge pull request #24 from jisraeli/master
Browse files (browse the repository at this point in the history)
architectures test with random sequences and labels; updated dependencies
  • Loading branch information
Wainberg authored Dec 29, 2016
2 parents d26bcaa + 2a92ed0 commit 5d0119e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 34 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ To install the latest released version of DragoNN, install the [Anaconda](https:
```
conda install dragonn -c kundajelab
```
DragoNN is compatible with Python2 and Python3. Specific optional features such as [DeepLIFT](https://github.com/kundajelab/deeplift) method and [MOE](https://github.com/Yelp/MOE) are compatible with Python2 only.
DragoNN is compatible with Python2 and Python3. Specific optional features such as [DeepLIFT](https://github.com/kundajelab/deeplift) and [MOE](https://github.com/Yelp/MOE) are compatible with Python2 only.


## 15 seconds to your first DragoNN model
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
'version': '0.1.2',
'packages': ['dragonn'],
'setup_requires': [],
'install_requires': ['numpy>=1.9', 'keras==0.3.3', 'deeplift==0.3', 'shapely', 'simdna==0.1', 'matplotlib<=1.5.3',
'install_requires': ['numpy>=1.9', 'keras==0.3.3', 'deeplift==0.3', 'shapely', 'simdna==0.2', 'matplotlib<=1.5.3',
'scikit-learn', 'pydot_ng==1.0.0', 'h5py'],
'dependency_links': ["https://github.com/kundajelab/deeplift/tarball/v0.3-alpha#egg=deeplift-0.3",
"https://github.com/kundajelab/simdna/tarball/0.1#egg=simdna-0.1"],
"https://github.com/kundajelab/simdna/tarball/0.2#egg=simdna-0.2"],
'scripts': [],
'entry_points': {'console_scripts': ['dragonn = dragonn.__main__:main']},
'name': 'dragonn'
Expand Down
52 changes: 21 additions & 31 deletions tests/test_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,19 @@
from collections import OrderedDict
from dragonn.models import SequenceDNN
from dragonn.utils import one_hot_encode, reverse_complement
from simdna.simulations import simulate_single_motif_detection
try:
from sklearn.model_selection import train_test_split # sklearn >= 0.18
except ImportError:
from sklearn.cross_validation import train_test_split # sklearn < 0.18


def run(use_deep_CNN, use_RNN, label, golden_first_sequence, golden_results):
def run(use_deep_CNN, use_RNN, label, golden_results):
seq_length = 100
num_sequences = 200
num_positives = 100
num_negatives = num_sequences - num_positives
GC_fraction = 0.4
test_fraction = 0.2
num_epochs = 1
sequences, labels, embeddings = simulate_single_motif_detection(
'SPI1_disc1', seq_length, num_positives, num_negatives, GC_fraction)
assert sequences[0] == golden_first_sequence, 'first sequence = {}, golden = {}'.format(
sequences[0], golden_first_sequence)
sequences = np.array([''.join(random.choice('ACGT') for base in range(seq_length)) for sequence in range(num_sequences)])
labels = np.random.choice((True, False), size=num_sequences)[:, None]
encoded_sequences = one_hot_encode(sequences)
X_train, X_test, y_train, y_test = train_test_split(
encoded_sequences, labels, test_size=test_fraction)
Expand All @@ -50,32 +44,28 @@ def run(use_deep_CNN, use_RNN, label, golden_first_sequence, golden_results):

def test_shallow_CNN():
run(use_deep_CNN=False, use_RNN=False, label='Shallow CNN',
golden_first_sequence='TTGAACAAGGTGAGTAATTCTAATAAGGCTGTTCAAATATGTTCCGTGTC'
'AATGTTATTAACAATCAGTAGAACAGTTCCCCTTATCTTAGTTAACGTGT',
golden_results=OrderedDict((('Loss', 1.613392511974697),
('Balanced accuracy', 50.0),
('auROC', 0.581453634085213),
('auPRC', 0.48312846202300236),
('Recall at 5% FDR', 0.0),
('Recall at 10% FDR', 0.0),
('Recall at 20% FDR', 0.0),
('Num Positives', 19),
('Num Negatives', 21))))
golden_results=OrderedDict([('Loss', 0.70371496533279465),
('Balanced accuracy', 55.639097744360896),
('auROC', 0.50877192982456143),
('auPRC', 0.58026674651508325),
('Recall at 5% FDR', 9.5238095238095237),
('Recall at 10% FDR', 9.5238095238095237),
('Recall at 20% FDR', 9.5238095238095237),
('Num Positives', 21),
('Num Negatives', 19)]))


def test_deep_CNN():
run(use_deep_CNN=True, use_RNN=False, label='Deep CNN',
golden_first_sequence='AACTCTGCTGATCTATTAGAGCTACTATCGTCCAAAGCCCTCGCTACTGC'
'TAGGATTATTGCTGAAGAGGAAGTAAATAATTTTTATTACCAATGCATGT',
golden_results=OrderedDict((('Loss', 0.9269361595503689),
('Balanced accuracy', 50.0),
('auROC', 0.34335839598997497),
('auPRC', 0.4227729924796752),
('Recall at 5% FDR', 0.0),
('Recall at 10% FDR', 0.0),
('Recall at 20% FDR', 0.0),
('Num Positives', 21),
('Num Negatives', 19))))
golden_results=OrderedDict([('Loss', 0.68411321005526782),
('Balanced accuracy', 45.833333333333329),
('auROC', 0.51822916666666663),
('auPRC', 0.41738642611750432),
('Recall at 5% FDR', 0.0),
('Recall at 10% FDR', 0.0),
('Recall at 20% FDR', 0.0),
('Num Positives', 16),
('Num Negatives', 24)]))


if __name__ == '__main__':
Expand Down

0 comments on commit 5d0119e

Please sign in to comment.