diff --git a/selene_sdk/evaluate_model.py b/selene_sdk/evaluate_model.py index 90146ad3..ab5931b0 100644 --- a/selene_sdk/evaluate_model.py +++ b/selene_sdk/evaluate_model.py @@ -209,24 +209,17 @@ def evaluate(self): """ batch_losses = [] all_predictions = [] - for (inputs, targets) in self._test_data: - inputs = torch.Tensor(inputs) - targets = torch.Tensor(targets[:, self._use_ixs]) + for samples_batch in self._test_data: + inputs, targets = samples_batch.torch_inputs_and_targets(self.use_cuda) + targets = targets[:, self._use_ixs] - if self.use_cuda: - inputs = inputs.cuda() - targets = targets.cuda() with torch.no_grad(): - inputs = Variable(inputs) - targets = Variable(targets) - predictions = None if _is_lua_trained_model(self.model): predictions = self.model.forward( - inputs.transpose(1, 2).contiguous().unsqueeze_(2)) + inputs.contiguous().unsqueeze_(2)) else: - predictions = self.model.forward( - inputs.transpose(1, 2)) + predictions = self.model.forward(inputs) predictions = predictions[:, self._use_ixs] loss = self.criterion(predictions, targets) diff --git a/selene_sdk/predict/tests/__init__.py b/selene_sdk/predict/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/selene_sdk/samplers/file_samplers/bed_file_sampler.py b/selene_sdk/samplers/file_samplers/bed_file_sampler.py index dfa16027..5592a35b 100644 --- a/selene_sdk/samplers/file_samplers/bed_file_sampler.py +++ b/selene_sdk/samplers/file_samplers/bed_file_sampler.py @@ -1,6 +1,7 @@ """ This module provides the BedFileSampler class. """ +from selene_sdk.samplers.samples_batch import SamplesBatch import numpy as np from .file_sampler import FileSampler @@ -96,8 +97,8 @@ def sample(self, batch_size=1): Returns ------- - sequences, targets : tuple(numpy.ndarray, numpy.ndarray) - A tuple containing the numeric representation of the + SamplesBatch + A batch containing the numeric representation of the sequence examples and their corresponding labels. The shape of `sequences` will be :math:`B \\times L \\times N`, where :math:`B` is @@ -163,8 +164,8 @@ def sample(self, batch_size=1): sequences = np.array(sequences) if self.targets_avail: targets = np.array(targets) - return (sequences, targets) - return sequences, + return SamplesBatch(sequences, target_batch=targets) + return SamplesBatch(sequences) def get_data(self, batch_size, n_samples=None): """ @@ -188,18 +189,21 @@ def get_data(self, batch_size, n_samples=None): and :math:`N` is the size of the sequence type's alphabet. """ + # TODO: Should this method return a collection of samples_batch.inputs()? + if not n_samples: n_samples = self.n_samples sequences = [] count = batch_size while count < n_samples: - seqs, = self.sample(batch_size=batch_size) - sequences.append(seqs) + samples_batch = self.sample(batch_size=batch_size) + sequences.append(samples_batch.sequence_batch()) count += batch_size remainder = batch_size - (count - n_samples) - seqs, = self.sample(batch_size=remainder) - sequences.append(seqs) + samples_batch = self.sample(batch_size=remainder) + sequences.append(samples_batch.sequence_batch()) + return sequences def get_data_and_targets(self, batch_size, n_samples=None): @@ -216,11 +220,11 @@ def get_data_and_targets(self, batch_size, n_samples=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, @@ -236,18 +240,18 @@ def get_data_and_targets(self, batch_size, n_samples=None): "Please use `get_data` instead.") if not n_samples: n_samples = self.n_samples - sequences_and_targets = [] + batches = [] targets_mat = [] count = batch_size while count < n_samples: - seqs, tgts = self.sample(batch_size=batch_size) - sequences_and_targets.append((seqs, tgts)) - targets_mat.append(tgts) + samples_batch = self.sample(batch_size=batch_size) + batches.append(samples_batch) + targets_mat.append(samples_batch.targets()) count += batch_size remainder = batch_size - (count - n_samples) - seqs, tgts = self.sample(batch_size=remainder) - sequences_and_targets.append((seqs, tgts)) - targets_mat.append(tgts) + samples_batch = self.sample(batch_size=remainder) + batches.append(samples_batch) + targets_mat.append(samples_batch.targets()) targets_mat = np.vstack(targets_mat).astype(int) - return sequences_and_targets, targets_mat + return batches, targets_mat diff --git a/selene_sdk/samplers/file_samplers/file_sampler.py b/selene_sdk/samplers/file_samplers/file_sampler.py index d29f7444..ec2628d5 100644 --- a/selene_sdk/samplers/file_samplers/file_sampler.py +++ b/selene_sdk/samplers/file_samplers/file_sampler.py @@ -8,6 +8,8 @@ from abc import ABCMeta from abc import abstractmethod +from selene_sdk.samplers.samples_batch import SamplesBatch + class FileSampler(metaclass=ABCMeta): """ @@ -26,7 +28,7 @@ def __init__(self): """ @abstractmethod - def sample(self, batch_size=1): + def sample(self, batch_size=1) -> SamplesBatch: """ Fetches a mini-batch of the data from the sampler. diff --git a/selene_sdk/samplers/file_samplers/mat_file_sampler.py b/selene_sdk/samplers/file_samplers/mat_file_sampler.py index f18bd1a2..28c77b34 100644 --- a/selene_sdk/samplers/file_samplers/mat_file_sampler.py +++ b/selene_sdk/samplers/file_samplers/mat_file_sampler.py @@ -5,6 +5,7 @@ import h5py import numpy as np import scipy.io +from selene_sdk.samplers.samples_batch import SamplesBatch from .file_sampler import FileSampler @@ -126,8 +127,8 @@ def sample(self, batch_size=1): Returns ------- - sequences, targets : tuple(numpy.ndarray, numpy.ndarray) - A tuple containing the numeric representation of the + SamplesBatch + A batch containing the numeric representation of the sequence examples and their corresponding labels. The shape of `sequences` will be :math:`B \\times L \\times N`, where :math:`B` is @@ -166,8 +167,8 @@ def sample(self, batch_size=1): targets = self._sample_tgts[:, use_indices].astype(float) targets = np.transpose( targets, (1, 0)) - return (sequences, targets) - return sequences, + return SamplesBatch(sequences, target_batch=targets) + return SamplesBatch(sequences) def get_data(self, batch_size, n_samples=None): """ @@ -190,18 +191,20 @@ def get_data(self, batch_size, n_samples=None): is `batch_size`, :math:`L` is the sequence length, and :math:`N` is the size of the sequence type's alphabet. """ + # TODO: Should this method return a collection of samples_batch.inputs()? + if not n_samples: n_samples = self.n_samples sequences = [] count = batch_size while count < n_samples: - seqs, = self.sample(batch_size=batch_size) - sequences.append(seqs) + samples_batch = self.sample(batch_size=batch_size) + sequences.append(samples_batch.sequence_batch()) count += batch_size remainder = batch_size - (count - n_samples) - seqs, = self.sample(batch_size=remainder) - sequences.append(seqs) + samples_batch = self.sample(batch_size=remainder) + sequences.append(samples_batch.sequence_batch()) return sequences def get_data_and_targets(self, batch_size, n_samples=None): @@ -218,11 +221,11 @@ def get_data_and_targets(self, batch_size, n_samples=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, @@ -237,19 +240,19 @@ def get_data_and_targets(self, batch_size, n_samples=None): "initialization. Please use `get_data` instead.") if not n_samples: n_samples = self.n_samples - sequences_and_targets = [] + batches = [] targets_mat = [] count = batch_size while count < n_samples: - seqs, tgts = self.sample(batch_size=batch_size) - sequences_and_targets.append((seqs, tgts)) - targets_mat.append(tgts) + samples_batch = self.sample(batch_size=batch_size) + batches.append(samples_batch) + targets_mat.append(samples_batch.targets()) count += batch_size remainder = batch_size - (count - n_samples) - seqs, tgts = self.sample(batch_size=remainder) - sequences_and_targets.append((seqs, tgts)) - targets_mat.append(tgts) + samples_batch = self.sample(batch_size=remainder) + batches.append(samples_batch) + targets_mat.append(samples_batch.targets()) # TODO: should not assume targets are always integers targets_mat = np.vstack(targets_mat).astype(float) - return sequences_and_targets, targets_mat + return batches, targets_mat diff --git a/selene_sdk/samplers/intervals_sampler.py b/selene_sdk/samplers/intervals_sampler.py index 497b0f40..5fe57d40 100644 --- a/selene_sdk/samplers/intervals_sampler.py +++ b/selene_sdk/samplers/intervals_sampler.py @@ -2,14 +2,15 @@ This module provides the `IntervalsSampler` class and supporting methods. """ -from collections import namedtuple import logging import random +from collections import namedtuple import numpy as np -from .online_sampler import OnlineSampler +from selene_sdk.samplers.samples_batch import SamplesBatch from ..utils import get_indices_and_probabilities +from .online_sampler import OnlineSampler logger = logging.getLogger(__name__) @@ -388,8 +389,8 @@ def sample(self, batch_size=1): Returns ------- - sequences, targets : tuple(numpy.ndarray, numpy.ndarray) - A tuple containing the numeric representation of the + SamplesBatch + A batch containing the numeric representation of the sequence examples and their corresponding labels. The shape of `sequences` will be :math:`B \\times L \\times N`, where :math:`B` is @@ -426,4 +427,4 @@ def sample(self, batch_size=1): sequences[n_samples_drawn, :, :] = seq targets[n_samples_drawn, :] = seq_targets n_samples_drawn += 1 - return (sequences, targets) + return SamplesBatch(sequences, target_batch=targets) diff --git a/selene_sdk/samplers/multi_file_sampler.py b/selene_sdk/samplers/multi_file_sampler.py index 5dbd0b99..ff0d1048 100644 --- a/selene_sdk/samplers/multi_file_sampler.py +++ b/selene_sdk/samplers/multi_file_sampler.py @@ -186,11 +186,11 @@ def get_test_set(self, batch_size, n_samples=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, diff --git a/selene_sdk/samplers/online_sampler.py b/selene_sdk/samplers/online_sampler.py index b84a26ca..1c735d5a 100644 --- a/selene_sdk/samplers/online_sampler.py +++ b/selene_sdk/samplers/online_sampler.py @@ -4,14 +4,14 @@ "on the fly" rather than storing them all persistently in memory. """ -from abc import ABCMeta import os import random +from abc import ABCMeta import numpy as np -from .sampler import Sampler from ..targets import GenomicFeatures +from .sampler import Sampler class OnlineSampler(Sampler, metaclass=ABCMeta): @@ -302,11 +302,11 @@ def get_data_and_targets(self, batch_size, n_samples=None, mode=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, @@ -320,7 +320,7 @@ def get_data_and_targets(self, batch_size, n_samples=None, mode=None): self.set_mode(mode) else: mode = self.mode - sequences_and_targets = [] + batches = [] if n_samples is None and mode == "validate": n_samples = 32000 elif n_samples is None and mode == "test": @@ -328,12 +328,12 @@ def get_data_and_targets(self, batch_size, n_samples=None, mode=None): n_batches = int(n_samples / batch_size) for _ in range(n_batches): - inputs, targets = self.sample(batch_size) - sequences_and_targets.append((inputs, targets)) - targets_mat = np.vstack([t for (s, t) in sequences_and_targets]) + samples_batch = self.sample(batch_size) + batches.append(samples_batch) + targets_mat = np.vstack([batch.targets() for batch in batches]) if mode in self._save_datasets: self.save_dataset_to_file(mode, close_filehandle=True) - return sequences_and_targets, targets_mat + return batches, targets_mat def get_dataset_in_batches(self, mode, batch_size, n_samples=None): """ @@ -355,12 +355,12 @@ def get_dataset_in_batches(self, mode, batch_size, n_samples=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. The list is length :math:`S`, where :math:`S =` `n_samples`. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, @@ -387,11 +387,11 @@ def get_validation_set(self, batch_size, n_samples=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, @@ -419,11 +419,11 @@ def get_test_set(self, batch_size, n_samples=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, diff --git a/selene_sdk/samplers/random_positions_sampler.py b/selene_sdk/samplers/random_positions_sampler.py index 73662a83..4d86b4fe 100644 --- a/selene_sdk/samplers/random_positions_sampler.py +++ b/selene_sdk/samplers/random_positions_sampler.py @@ -8,8 +8,10 @@ import logging import random + import numpy as np +from selene_sdk.samplers.samples_batch import SamplesBatch from .online_sampler import OnlineSampler from ..utils import get_indices_and_probabilities @@ -305,8 +307,8 @@ def sample(self, batch_size=1): Returns ------- - sequences, targets : tuple(numpy.ndarray, numpy.ndarray) - A tuple containing the numeric representation of the + SamplesBatch + A batch containing the numeric representation of the sequence examples and their corresponding labels. The shape of `sequences` will be :math:`B \\times L \\times N`, where :math:`B` is @@ -340,4 +342,4 @@ def sample(self, batch_size=1): sequences[n_samples_drawn, :, :] = seq targets[n_samples_drawn, :] = seq_targets n_samples_drawn += 1 - return (sequences, targets) + return SamplesBatch(sequences, target_batch=targets) diff --git a/selene_sdk/samplers/sampler.py b/selene_sdk/samplers/sampler.py index 8d3a3dcb..c863a84f 100644 --- a/selene_sdk/samplers/sampler.py +++ b/selene_sdk/samplers/sampler.py @@ -6,6 +6,7 @@ from abc import abstractmethod import os +from selene_sdk.samplers.samples_batch import SamplesBatch class Sampler(metaclass=ABCMeta): """ @@ -98,7 +99,7 @@ def get_feature_from_index(self, index): raise NotImplementedError() @abstractmethod - def sample(self, batch_size=1): + def sample(self, batch_size=1) -> SamplesBatch: """ Fetches a mini-batch of the data from the sampler. @@ -107,6 +108,10 @@ def sample(self, batch_size=1): batch_size : int, optional Default is 1. The size of the batch to retrieve. + Returns + ------- + samples_batch : SamplesBatch + A struct containing inputs and targets in numpy format. """ raise NotImplementedError() @@ -146,6 +151,19 @@ def get_validation_set(self, batch_size, n_samples=None): retrieve. Handling for `n_samples=None` should be done by all classes that subclass `selene_sdk.samplers.Sampler`. + Returns + ------- + batches, targets_matrix : tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches (with targets), as well + as a single matrix with all targets in the same order. + Note that `batches`'s sequence elements are of + the shape :math:`B \\times L \\times N` and its target + elements are of the shape :math:`B \\times F`, where + :math:`B` is `batch_size`, :math:`L` is the sequence length, + :math:`N` is the size of the sequence type's alphabet, and + :math:`F` is the number of features. Further, + `target_matrix` is of the shape :math:`S \\times F`, where + :math:`S =` `n_samples`. """ raise NotImplementedError() @@ -160,16 +178,16 @@ def get_test_set(self, batch_size, n_samples=None): batch_size : int The size of the batches to divide the data into. n_samples : int or None, optional - Default is `None`. The total number of validation examples - to retrieve. If `None`, 640000 examples are retrieved. + Default is `None`. The total number of test examples + to retrieve. Handling for `n_samples=None` should be done by + all classes that subclass `selene_sdk.samplers.Sampler`. Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches (with targets), as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, diff --git a/selene_sdk/samplers/samples_batch.py b/selene_sdk/samplers/samples_batch.py new file mode 100644 index 00000000..8bade930 --- /dev/null +++ b/selene_sdk/samplers/samples_batch.py @@ -0,0 +1,106 @@ +import numpy as np +import torch + + +class SamplesBatch: + """ + This class represents NN inputs and targets. Values are stored as numpy.ndarrays + and there is a method to convert them to torch.Tensors. + + Inputs are stored in a dict, which can be used if you are providing more than just a + `sequence_batch` to the NN. + + NOTE: If you store just a sequence as an input to the model, then `inputs()` and + `torch_inputs_and_targets()` will return only the batch of sequences rather than + a dict. + + """ + + _SEQUENCE_LABEL = "sequence_batch" + + def __init__( + self, + sequence_batch: np.ndarray, + other_input_batches=dict(), + target_batch: np.ndarray = None, + ) -> None: + self._input_batches = other_input_batches.copy() + self._input_batches[self._SEQUENCE_LABEL] = sequence_batch + self._target_batch = target_batch + + def sequence_batch(self) -> torch.Tensor: + """Returns the sequence batch with a shape of + [batch_size, sequence_length, alphabet_size]. + """ + return self._input_batches[self._SEQUENCE_LABEL] + + def inputs(self): + """Based on the size of inputs dictionary, returns either just the + sequence or the whole dictionary. + + Returns + ------- + numpy.ndarray or dict[str, numpy.ndarray] + numpy.ndarray is returned when inputs contain just the sequence batch. + dict[str, numpy.ndarray] is returned when there are multiple inputs. + + NOTE: Sequence batch has a shape of + [batch_size, sequence_length, alphabet_size]. + """ + if len(self._input_batches) == 1: + return self.sequence_batch() + + return self._input_batches + + def targets(self): + """Returns target batch if it is present. + + Returns + ------- + numpy.ndarray + + """ + return self._target_batch + + def torch_inputs_and_targets(self, use_cuda: bool): + """ + Returns inputs and targets in torch.Tensor format. + + Based on the size of inputs dictionary, returns either just the + sequence or the whole dictionary. + + Returns + ------- + inputs, targets :\ + tuple(numpy.ndarray or dict[str, numpy.ndarray], numpy.ndarray) + For `inputs`: + numpy.ndarray is returned when inputs contain just the sequence batch. + dict[str, numpy.ndarray] is returned when there are multiple inputs. + + NOTE: Returned sequence batch has a shape of + [batch_size, alphabet_size, sequence_length]. + + """ + all_inputs = dict() + for key, value in self._input_batches.items(): + all_inputs[key] = torch.Tensor(value) + + if use_cuda: + all_inputs[key] = all_inputs[key].cuda() + + # Transpose the sequences to satisfy NN convolution input format (which is + # [batch_size, channels_size, sequence_length]). + all_inputs[self._SEQUENCE_LABEL] = all_inputs[self._SEQUENCE_LABEL].transpose( + 1, 2 + ) + + inputs = all_inputs if len(all_inputs) > 1 else all_inputs[self._SEQUENCE_LABEL] + + targets = None + if self._target_batch is not None: + targets = torch.Tensor(self._target_batch) + + if use_cuda: + targets = targets.cuda() + + return inputs, targets diff --git a/selene_sdk/samplers/tests/__init__.py b/selene_sdk/samplers/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/selene_sdk/samplers/tests/test_samples_batch.py b/selene_sdk/samplers/tests/test_samples_batch.py new file mode 100644 index 00000000..c2f4c055 --- /dev/null +++ b/selene_sdk/samplers/tests/test_samples_batch.py @@ -0,0 +1,48 @@ +import unittest + +import numpy as np +import torch +from selene_sdk.samplers.samples_batch import SamplesBatch + + +class TestSamplesBatch(unittest.TestCase): + def test_single_input(self): + samples_batch = SamplesBatch(np.ones((6, 200, 4)), target_batch=np.ones(20)) + + inputs = samples_batch.inputs() + self.assertIsInstance(inputs, np.ndarray) + self.assertSequenceEqual(inputs.shape, (6, 200, 4)) + + torch_inputs, _ = samples_batch.torch_inputs_and_targets(use_cuda=False) + self.assertIsInstance(torch_inputs, torch.Tensor) + self.assertSequenceEqual(torch_inputs.shape, (6, 4, 200)) + + def test_multiple_inputs(self): + samples_batch = SamplesBatch( + np.ones((6, 200, 4)), + other_input_batches={"something": np.ones(10)}, + target_batch=np.ones(20), + ) + + inputs = samples_batch.inputs() + self.assertIsInstance(inputs, dict) + self.assertEqual(len(inputs), 2) + self.assertSequenceEqual(inputs["sequence_batch"].shape, (6, 200, 4)) + + torch_inputs, _ = samples_batch.torch_inputs_and_targets(use_cuda=False) + self.assertIsInstance(torch_inputs, dict) + self.assertEqual(len(torch_inputs), 2) + self.assertSequenceEqual(torch_inputs["sequence_batch"].shape, (6, 4, 200)) + + def test_has_target(self): + samples_batch = SamplesBatch(np.ones((6, 200, 4)), target_batch=np.ones(20)) + targets = samples_batch.targets() + self.assertIsInstance(targets, np.ndarray) + _, torch_targets = samples_batch.torch_inputs_and_targets(use_cuda=False) + self.assertIsInstance(torch_targets, torch.Tensor) + + def test_no_target(self): + samples_batch = SamplesBatch(np.ones((6, 200, 4))) + self.assertIsNone(samples_batch.targets()) + _, torch_targets = samples_batch.torch_inputs_and_targets(use_cuda=False) + self.assertIsNone(torch_targets) diff --git a/selene_sdk/sequences/tests/__init__.py b/selene_sdk/sequences/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/selene_sdk/targets/tests/__init__.py b/selene_sdk/targets/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/selene_sdk/train_model.py b/selene_sdk/train_model.py index 08aaaa85..854e09de 100644 --- a/selene_sdk/train_model.py +++ b/selene_sdk/train_model.py @@ -5,20 +5,17 @@ import math import os import shutil -from time import strftime -from time import time +from time import strftime, time import numpy as np import torch import torch.nn as nn +from sklearn.metrics import average_precision_score, roc_auc_score from torch.autograd import Variable from torch.optim.lr_scheduler import ReduceLROnPlateau -from sklearn.metrics import roc_auc_score -from sklearn.metrics import average_precision_score -from .utils import initialize_logger -from .utils import load_model_from_state_dict -from .utils import PerformanceMetrics +from .utils import (PerformanceMetrics, initialize_logger, + load_model_from_state_dict) logger = logging.getLogger("selene") @@ -349,19 +346,18 @@ def _get_batch(self): Returns ------- - tuple(numpy.ndarray, numpy.ndarray) - A tuple containing the examples and targets. + SamplesBatch + A batch containing the examples and targets. """ t_i_sampling = time() - batch_sequences, batch_targets = self.sampler.sample( - batch_size=self.batch_size) + samples_batch = self.sampler.sample(batch_size=self.batch_size) t_f_sampling = time() logger.debug( ("[BATCH] Time to sample {0} examples: {1} s.").format( self.batch_size, t_f_sampling - t_i_sampling)) - return (batch_sequences, batch_targets) + return samples_batch def train_and_validate(self): """ @@ -450,18 +446,9 @@ def train(self): self.model.train() self.sampler.set_mode("train") - inputs, targets = self._get_batch() - inputs = torch.Tensor(inputs) - targets = torch.Tensor(targets) - - if self.use_cuda: - inputs = inputs.cuda() - targets = targets.cuda() - - inputs = Variable(inputs) - targets = Variable(targets) - - predictions = self.model(inputs.transpose(1, 2)) + samples_batch = self._get_batch() + inputs, targets = samples_batch.torch_inputs_and_targets(self.use_cuda) + predictions = self.model(inputs) loss = self.criterion(predictions, targets) self.optimizer.zero_grad() @@ -476,7 +463,7 @@ def _evaluate_on_data(self, data_in_batches): Parameters ---------- - data_in_batches : list(tuple(numpy.ndarray, numpy.ndarray)) + data_in_batches : list(SamplesBatch) A list of tuples of the data, where the first element is the example, and the second element is the label. @@ -491,19 +478,11 @@ def _evaluate_on_data(self, data_in_batches): batch_losses = [] all_predictions = [] - for (inputs, targets) in data_in_batches: - inputs = torch.Tensor(inputs) - targets = torch.Tensor(targets) - - if self.use_cuda: - inputs = inputs.cuda() - targets = targets.cuda() + for samples_batch in data_in_batches: + inputs, targets = samples_batch.torch_inputs_and_targets(self.use_cuda) with torch.no_grad(): - inputs = Variable(inputs) - targets = Variable(targets) - - predictions = self.model(inputs.transpose(1, 2)) + predictions = self.model(inputs) loss = self.criterion(predictions, targets) all_predictions.append( @@ -622,4 +601,3 @@ def _save_checkpoint(self, best_filepath = os.path.join(self.output_dir, "best_model") shutil.copyfile("{0}.pth.tar".format(cp_filepath), "{0}.pth.tar".format(best_filepath)) - diff --git a/selene_sdk/utils/non_strand_specific_module.py b/selene_sdk/utils/non_strand_specific_module.py index 367a46a9..59698f45 100644 --- a/selene_sdk/utils/non_strand_specific_module.py +++ b/selene_sdk/utils/non_strand_specific_module.py @@ -43,8 +43,6 @@ class NonStrandSpecific(Module): ---------- model : torch.nn.Module The user-specified model architecture. - mode : {'mean', 'max'} - How to handle outputting a non-strand specific prediction. """ @@ -53,25 +51,46 @@ def __init__(self, model, mode="mean"): self.model = model - if mode != "mean" and mode != "max": + if mode == "mean": + self.reduce_fn = lambda x, y: (x + y) / 2 + elif mode == "max": + self.reduce_fn = torch.max + else: raise ValueError("Mode should be one of 'mean' or 'max' but was" "{0}.".format(mode)) - self.mode = mode + self.from_lua = _is_lua_trained_model(model) - def forward(self, input): - reverse_input = None + def _forward_input_with_reversed_sequence(self, input): + multi_inputs = isinstance(input, dict) + sequence = input if not multi_inputs else input["sequence_batch"] + reversed_sequence = None if self.from_lua: - reverse_input = _flip( - _flip(torch.squeeze(input, 2), 1), 2).unsqueeze_(2) + reversed_sequence = _flip( + _flip(torch.squeeze(sequence, 2), 1), 2).unsqueeze_(2) else: - reverse_input = _flip(_flip(input, 1), 2) + reversed_sequence = _flip(_flip(sequence, 1), 2) - output = self.model.forward(input) - output_from_rev = self.model.forward(reverse_input) - - if self.mode == "mean": - return (output + output_from_rev) / 2 + input_rev = None + if multi_inputs: + input_rev = input.copy() + input_rev["sequence_batch"] = reversed_sequence else: - return torch.max(output, output_from_rev) + input_rev = reversed_sequence + + return self.model.forward(input_rev) + + def forward(self, input): + """Computes NN output for the given sequence and for a reversed sequence, + applies `self.reduce_fn` function to those outputs, and returns the result. + + Parameters + ---------- + input : numpy.ndarray or dict(str, numpy.ndarray) + Model's inputs. Can be just a sequence or multi-inputs. + + """ + output = self.model.forward(input) + output_from_rev = self._forward_input_with_reversed_sequence(input) + return self.reduce_fn(output, output_from_rev)