From 32d0c899a6683a96f2b8836f4bf9e5f5d4e2300b Mon Sep 17 00:00:00 2001 From: Alex Lapin Date: Tue, 2 Mar 2021 17:03:39 +0000 Subject: [PATCH] Utilize SamplesBatch class --- selene_sdk/evaluate_model.py | 17 ++---- .../file_samplers/bed_file_sampler.py | 44 +++++++++------- .../samplers/file_samplers/file_sampler.py | 4 +- .../file_samplers/mat_file_sampler.py | 43 ++++++++------- selene_sdk/samplers/intervals_sampler.py | 11 ++-- selene_sdk/samplers/multi_file_sampler.py | 8 +-- selene_sdk/samplers/online_sampler.py | 46 ++++++++-------- .../samplers/random_positions_sampler.py | 8 +-- selene_sdk/samplers/sampler.py | 32 +++++++++--- selene_sdk/train_model.py | 52 ++++++------------- .../utils/non_strand_specific_module.py | 49 +++++++++++------ 11 files changed, 167 insertions(+), 147 deletions(-) diff --git a/selene_sdk/evaluate_model.py b/selene_sdk/evaluate_model.py index 90146ad3..ab5931b0 100644 --- a/selene_sdk/evaluate_model.py +++ b/selene_sdk/evaluate_model.py @@ -209,24 +209,17 @@ def evaluate(self): """ batch_losses = [] all_predictions = [] - for (inputs, targets) in self._test_data: - inputs = torch.Tensor(inputs) - targets = torch.Tensor(targets[:, self._use_ixs]) + for samples_batch in self._test_data: + inputs, targets = samples_batch.torch_inputs_and_targets(self.use_cuda) + targets = targets[:, self._use_ixs] - if self.use_cuda: - inputs = inputs.cuda() - targets = targets.cuda() with torch.no_grad(): - inputs = Variable(inputs) - targets = Variable(targets) - predictions = None if _is_lua_trained_model(self.model): predictions = self.model.forward( - inputs.transpose(1, 2).contiguous().unsqueeze_(2)) + inputs.contiguous().unsqueeze_(2)) else: - predictions = self.model.forward( - inputs.transpose(1, 2)) + predictions = self.model.forward(inputs) predictions = predictions[:, self._use_ixs] loss = self.criterion(predictions, targets) diff --git a/selene_sdk/samplers/file_samplers/bed_file_sampler.py b/selene_sdk/samplers/file_samplers/bed_file_sampler.py index dfa16027..5592a35b 100644 --- a/selene_sdk/samplers/file_samplers/bed_file_sampler.py +++ b/selene_sdk/samplers/file_samplers/bed_file_sampler.py @@ -1,6 +1,7 @@ """ This module provides the BedFileSampler class. """ +from selene_sdk.samplers.samples_batch import SamplesBatch import numpy as np from .file_sampler import FileSampler @@ -96,8 +97,8 @@ def sample(self, batch_size=1): Returns ------- - sequences, targets : tuple(numpy.ndarray, numpy.ndarray) - A tuple containing the numeric representation of the + SamplesBatch + A batch containing the numeric representation of the sequence examples and their corresponding labels. The shape of `sequences` will be :math:`B \\times L \\times N`, where :math:`B` is @@ -163,8 +164,8 @@ def sample(self, batch_size=1): sequences = np.array(sequences) if self.targets_avail: targets = np.array(targets) - return (sequences, targets) - return sequences, + return SamplesBatch(sequences, target_batch=targets) + return SamplesBatch(sequences) def get_data(self, batch_size, n_samples=None): """ @@ -188,18 +189,21 @@ def get_data(self, batch_size, n_samples=None): and :math:`N` is the size of the sequence type's alphabet. """ + # TODO: Should this method return a collection of samples_batch.inputs()? + if not n_samples: n_samples = self.n_samples sequences = [] count = batch_size while count < n_samples: - seqs, = self.sample(batch_size=batch_size) - sequences.append(seqs) + samples_batch = self.sample(batch_size=batch_size) + sequences.append(samples_batch.sequence_batch()) count += batch_size remainder = batch_size - (count - n_samples) - seqs, = self.sample(batch_size=remainder) - sequences.append(seqs) + samples_batch = self.sample(batch_size=remainder) + sequences.append(samples_batch.sequence_batch()) + return sequences def get_data_and_targets(self, batch_size, n_samples=None): @@ -216,11 +220,11 @@ def get_data_and_targets(self, batch_size, n_samples=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, @@ -236,18 +240,18 @@ def get_data_and_targets(self, batch_size, n_samples=None): "Please use `get_data` instead.") if not n_samples: n_samples = self.n_samples - sequences_and_targets = [] + batches = [] targets_mat = [] count = batch_size while count < n_samples: - seqs, tgts = self.sample(batch_size=batch_size) - sequences_and_targets.append((seqs, tgts)) - targets_mat.append(tgts) + samples_batch = self.sample(batch_size=batch_size) + batches.append(samples_batch) + targets_mat.append(samples_batch.targets()) count += batch_size remainder = batch_size - (count - n_samples) - seqs, tgts = self.sample(batch_size=remainder) - sequences_and_targets.append((seqs, tgts)) - targets_mat.append(tgts) + samples_batch = self.sample(batch_size=remainder) + batches.append(samples_batch) + targets_mat.append(samples_batch.targets()) targets_mat = np.vstack(targets_mat).astype(int) - return sequences_and_targets, targets_mat + return batches, targets_mat diff --git a/selene_sdk/samplers/file_samplers/file_sampler.py b/selene_sdk/samplers/file_samplers/file_sampler.py index d29f7444..ec2628d5 100644 --- a/selene_sdk/samplers/file_samplers/file_sampler.py +++ b/selene_sdk/samplers/file_samplers/file_sampler.py @@ -8,6 +8,8 @@ from abc import ABCMeta from abc import abstractmethod +from selene_sdk.samplers.samples_batch import SamplesBatch + class FileSampler(metaclass=ABCMeta): """ @@ -26,7 +28,7 @@ def __init__(self): """ @abstractmethod - def sample(self, batch_size=1): + def sample(self, batch_size=1) -> SamplesBatch: """ Fetches a mini-batch of the data from the sampler. diff --git a/selene_sdk/samplers/file_samplers/mat_file_sampler.py b/selene_sdk/samplers/file_samplers/mat_file_sampler.py index f18bd1a2..28c77b34 100644 --- a/selene_sdk/samplers/file_samplers/mat_file_sampler.py +++ b/selene_sdk/samplers/file_samplers/mat_file_sampler.py @@ -5,6 +5,7 @@ import h5py import numpy as np import scipy.io +from selene_sdk.samplers.samples_batch import SamplesBatch from .file_sampler import FileSampler @@ -126,8 +127,8 @@ def sample(self, batch_size=1): Returns ------- - sequences, targets : tuple(numpy.ndarray, numpy.ndarray) - A tuple containing the numeric representation of the + SamplesBatch + A batch containing the numeric representation of the sequence examples and their corresponding labels. The shape of `sequences` will be :math:`B \\times L \\times N`, where :math:`B` is @@ -166,8 +167,8 @@ def sample(self, batch_size=1): targets = self._sample_tgts[:, use_indices].astype(float) targets = np.transpose( targets, (1, 0)) - return (sequences, targets) - return sequences, + return SamplesBatch(sequences, target_batch=targets) + return SamplesBatch(sequences) def get_data(self, batch_size, n_samples=None): """ @@ -190,18 +191,20 @@ def get_data(self, batch_size, n_samples=None): is `batch_size`, :math:`L` is the sequence length, and :math:`N` is the size of the sequence type's alphabet. """ + # TODO: Should this method return a collection of samples_batch.inputs()? + if not n_samples: n_samples = self.n_samples sequences = [] count = batch_size while count < n_samples: - seqs, = self.sample(batch_size=batch_size) - sequences.append(seqs) + samples_batch = self.sample(batch_size=batch_size) + sequences.append(samples_batch.sequence_batch()) count += batch_size remainder = batch_size - (count - n_samples) - seqs, = self.sample(batch_size=remainder) - sequences.append(seqs) + samples_batch = self.sample(batch_size=remainder) + sequences.append(samples_batch.sequence_batch()) return sequences def get_data_and_targets(self, batch_size, n_samples=None): @@ -218,11 +221,11 @@ def get_data_and_targets(self, batch_size, n_samples=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, @@ -237,19 +240,19 @@ def get_data_and_targets(self, batch_size, n_samples=None): "initialization. Please use `get_data` instead.") if not n_samples: n_samples = self.n_samples - sequences_and_targets = [] + batches = [] targets_mat = [] count = batch_size while count < n_samples: - seqs, tgts = self.sample(batch_size=batch_size) - sequences_and_targets.append((seqs, tgts)) - targets_mat.append(tgts) + samples_batch = self.sample(batch_size=batch_size) + batches.append(samples_batch) + targets_mat.append(samples_batch.targets()) count += batch_size remainder = batch_size - (count - n_samples) - seqs, tgts = self.sample(batch_size=remainder) - sequences_and_targets.append((seqs, tgts)) - targets_mat.append(tgts) + samples_batch = self.sample(batch_size=remainder) + batches.append(samples_batch) + targets_mat.append(samples_batch.targets()) # TODO: should not assume targets are always integers targets_mat = np.vstack(targets_mat).astype(float) - return sequences_and_targets, targets_mat + return batches, targets_mat diff --git a/selene_sdk/samplers/intervals_sampler.py b/selene_sdk/samplers/intervals_sampler.py index 497b0f40..5fe57d40 100644 --- a/selene_sdk/samplers/intervals_sampler.py +++ b/selene_sdk/samplers/intervals_sampler.py @@ -2,14 +2,15 @@ This module provides the `IntervalsSampler` class and supporting methods. """ -from collections import namedtuple import logging import random +from collections import namedtuple import numpy as np -from .online_sampler import OnlineSampler +from selene_sdk.samplers.samples_batch import SamplesBatch from ..utils import get_indices_and_probabilities +from .online_sampler import OnlineSampler logger = logging.getLogger(__name__) @@ -388,8 +389,8 @@ def sample(self, batch_size=1): Returns ------- - sequences, targets : tuple(numpy.ndarray, numpy.ndarray) - A tuple containing the numeric representation of the + SamplesBatch + A batch containing the numeric representation of the sequence examples and their corresponding labels. The shape of `sequences` will be :math:`B \\times L \\times N`, where :math:`B` is @@ -426,4 +427,4 @@ def sample(self, batch_size=1): sequences[n_samples_drawn, :, :] = seq targets[n_samples_drawn, :] = seq_targets n_samples_drawn += 1 - return (sequences, targets) + return SamplesBatch(sequences, target_batch=targets) diff --git a/selene_sdk/samplers/multi_file_sampler.py b/selene_sdk/samplers/multi_file_sampler.py index 5dbd0b99..ff0d1048 100644 --- a/selene_sdk/samplers/multi_file_sampler.py +++ b/selene_sdk/samplers/multi_file_sampler.py @@ -186,11 +186,11 @@ def get_test_set(self, batch_size, n_samples=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, diff --git a/selene_sdk/samplers/online_sampler.py b/selene_sdk/samplers/online_sampler.py index b84a26ca..1c735d5a 100644 --- a/selene_sdk/samplers/online_sampler.py +++ b/selene_sdk/samplers/online_sampler.py @@ -4,14 +4,14 @@ "on the fly" rather than storing them all persistently in memory. """ -from abc import ABCMeta import os import random +from abc import ABCMeta import numpy as np -from .sampler import Sampler from ..targets import GenomicFeatures +from .sampler import Sampler class OnlineSampler(Sampler, metaclass=ABCMeta): @@ -302,11 +302,11 @@ def get_data_and_targets(self, batch_size, n_samples=None, mode=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, @@ -320,7 +320,7 @@ def get_data_and_targets(self, batch_size, n_samples=None, mode=None): self.set_mode(mode) else: mode = self.mode - sequences_and_targets = [] + batches = [] if n_samples is None and mode == "validate": n_samples = 32000 elif n_samples is None and mode == "test": @@ -328,12 +328,12 @@ def get_data_and_targets(self, batch_size, n_samples=None, mode=None): n_batches = int(n_samples / batch_size) for _ in range(n_batches): - inputs, targets = self.sample(batch_size) - sequences_and_targets.append((inputs, targets)) - targets_mat = np.vstack([t for (s, t) in sequences_and_targets]) + samples_batch = self.sample(batch_size) + batches.append(samples_batch) + targets_mat = np.vstack([batch.targets() for batch in batches]) if mode in self._save_datasets: self.save_dataset_to_file(mode, close_filehandle=True) - return sequences_and_targets, targets_mat + return batches, targets_mat def get_dataset_in_batches(self, mode, batch_size, n_samples=None): """ @@ -355,12 +355,12 @@ def get_dataset_in_batches(self, mode, batch_size, n_samples=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. The list is length :math:`S`, where :math:`S =` `n_samples`. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, @@ -387,11 +387,11 @@ def get_validation_set(self, batch_size, n_samples=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, @@ -419,11 +419,11 @@ def get_test_set(self, batch_size, n_samples=None): Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : \ + tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches, as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, diff --git a/selene_sdk/samplers/random_positions_sampler.py b/selene_sdk/samplers/random_positions_sampler.py index 73662a83..4d86b4fe 100644 --- a/selene_sdk/samplers/random_positions_sampler.py +++ b/selene_sdk/samplers/random_positions_sampler.py @@ -8,8 +8,10 @@ import logging import random + import numpy as np +from selene_sdk.samplers.samples_batch import SamplesBatch from .online_sampler import OnlineSampler from ..utils import get_indices_and_probabilities @@ -305,8 +307,8 @@ def sample(self, batch_size=1): Returns ------- - sequences, targets : tuple(numpy.ndarray, numpy.ndarray) - A tuple containing the numeric representation of the + SamplesBatch + A batch containing the numeric representation of the sequence examples and their corresponding labels. The shape of `sequences` will be :math:`B \\times L \\times N`, where :math:`B` is @@ -340,4 +342,4 @@ def sample(self, batch_size=1): sequences[n_samples_drawn, :, :] = seq targets[n_samples_drawn, :] = seq_targets n_samples_drawn += 1 - return (sequences, targets) + return SamplesBatch(sequences, target_batch=targets) diff --git a/selene_sdk/samplers/sampler.py b/selene_sdk/samplers/sampler.py index 8d3a3dcb..c863a84f 100644 --- a/selene_sdk/samplers/sampler.py +++ b/selene_sdk/samplers/sampler.py @@ -6,6 +6,7 @@ from abc import abstractmethod import os +from selene_sdk.samplers.samples_batch import SamplesBatch class Sampler(metaclass=ABCMeta): """ @@ -98,7 +99,7 @@ def get_feature_from_index(self, index): raise NotImplementedError() @abstractmethod - def sample(self, batch_size=1): + def sample(self, batch_size=1) -> SamplesBatch: """ Fetches a mini-batch of the data from the sampler. @@ -107,6 +108,10 @@ def sample(self, batch_size=1): batch_size : int, optional Default is 1. The size of the batch to retrieve. + Returns + ------- + samples_batch : SamplesBatch + A struct containing inputs and targets in numpy format. """ raise NotImplementedError() @@ -146,6 +151,19 @@ def get_validation_set(self, batch_size, n_samples=None): retrieve. Handling for `n_samples=None` should be done by all classes that subclass `selene_sdk.samplers.Sampler`. + Returns + ------- + batches, targets_matrix : tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches (with targets), as well + as a single matrix with all targets in the same order. + Note that `batches`'s sequence elements are of + the shape :math:`B \\times L \\times N` and its target + elements are of the shape :math:`B \\times F`, where + :math:`B` is `batch_size`, :math:`L` is the sequence length, + :math:`N` is the size of the sequence type's alphabet, and + :math:`F` is the number of features. Further, + `target_matrix` is of the shape :math:`S \\times F`, where + :math:`S =` `n_samples`. """ raise NotImplementedError() @@ -160,16 +178,16 @@ def get_test_set(self, batch_size, n_samples=None): batch_size : int The size of the batches to divide the data into. n_samples : int or None, optional - Default is `None`. The total number of validation examples - to retrieve. If `None`, 640000 examples are retrieved. + Default is `None`. The total number of test examples + to retrieve. Handling for `n_samples=None` should be done by + all classes that subclass `selene_sdk.samplers.Sampler`. Returns ------- - sequences_and_targets, targets_matrix : \ - tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray) - Tuple containing the list of sequence-target pairs, as well + batches, targets_matrix : tuple(list(SamplesBatch), numpy.ndarray) + Tuple containing the list of batches (with targets), as well as a single matrix with all targets in the same order. - Note that `sequences_and_targets`'s sequence elements are of + Note that `batches`'s sequence elements are of the shape :math:`B \\times L \\times N` and its target elements are of the shape :math:`B \\times F`, where :math:`B` is `batch_size`, :math:`L` is the sequence length, diff --git a/selene_sdk/train_model.py b/selene_sdk/train_model.py index 08aaaa85..854e09de 100644 --- a/selene_sdk/train_model.py +++ b/selene_sdk/train_model.py @@ -5,20 +5,17 @@ import math import os import shutil -from time import strftime -from time import time +from time import strftime, time import numpy as np import torch import torch.nn as nn +from sklearn.metrics import average_precision_score, roc_auc_score from torch.autograd import Variable from torch.optim.lr_scheduler import ReduceLROnPlateau -from sklearn.metrics import roc_auc_score -from sklearn.metrics import average_precision_score -from .utils import initialize_logger -from .utils import load_model_from_state_dict -from .utils import PerformanceMetrics +from .utils import (PerformanceMetrics, initialize_logger, + load_model_from_state_dict) logger = logging.getLogger("selene") @@ -349,19 +346,18 @@ def _get_batch(self): Returns ------- - tuple(numpy.ndarray, numpy.ndarray) - A tuple containing the examples and targets. + SamplesBatch + A batch containing the examples and targets. """ t_i_sampling = time() - batch_sequences, batch_targets = self.sampler.sample( - batch_size=self.batch_size) + samples_batch = self.sampler.sample(batch_size=self.batch_size) t_f_sampling = time() logger.debug( ("[BATCH] Time to sample {0} examples: {1} s.").format( self.batch_size, t_f_sampling - t_i_sampling)) - return (batch_sequences, batch_targets) + return samples_batch def train_and_validate(self): """ @@ -450,18 +446,9 @@ def train(self): self.model.train() self.sampler.set_mode("train") - inputs, targets = self._get_batch() - inputs = torch.Tensor(inputs) - targets = torch.Tensor(targets) - - if self.use_cuda: - inputs = inputs.cuda() - targets = targets.cuda() - - inputs = Variable(inputs) - targets = Variable(targets) - - predictions = self.model(inputs.transpose(1, 2)) + samples_batch = self._get_batch() + inputs, targets = samples_batch.torch_inputs_and_targets(self.use_cuda) + predictions = self.model(inputs) loss = self.criterion(predictions, targets) self.optimizer.zero_grad() @@ -476,7 +463,7 @@ def _evaluate_on_data(self, data_in_batches): Parameters ---------- - data_in_batches : list(tuple(numpy.ndarray, numpy.ndarray)) + data_in_batches : list(SamplesBatch) A list of tuples of the data, where the first element is the example, and the second element is the label. @@ -491,19 +478,11 @@ def _evaluate_on_data(self, data_in_batches): batch_losses = [] all_predictions = [] - for (inputs, targets) in data_in_batches: - inputs = torch.Tensor(inputs) - targets = torch.Tensor(targets) - - if self.use_cuda: - inputs = inputs.cuda() - targets = targets.cuda() + for samples_batch in data_in_batches: + inputs, targets = samples_batch.torch_inputs_and_targets(self.use_cuda) with torch.no_grad(): - inputs = Variable(inputs) - targets = Variable(targets) - - predictions = self.model(inputs.transpose(1, 2)) + predictions = self.model(inputs) loss = self.criterion(predictions, targets) all_predictions.append( @@ -622,4 +601,3 @@ def _save_checkpoint(self, best_filepath = os.path.join(self.output_dir, "best_model") shutil.copyfile("{0}.pth.tar".format(cp_filepath), "{0}.pth.tar".format(best_filepath)) - diff --git a/selene_sdk/utils/non_strand_specific_module.py b/selene_sdk/utils/non_strand_specific_module.py index 367a46a9..59698f45 100644 --- a/selene_sdk/utils/non_strand_specific_module.py +++ b/selene_sdk/utils/non_strand_specific_module.py @@ -43,8 +43,6 @@ class NonStrandSpecific(Module): ---------- model : torch.nn.Module The user-specified model architecture. - mode : {'mean', 'max'} - How to handle outputting a non-strand specific prediction. """ @@ -53,25 +51,46 @@ def __init__(self, model, mode="mean"): self.model = model - if mode != "mean" and mode != "max": + if mode == "mean": + self.reduce_fn = lambda x, y: (x + y) / 2 + elif mode == "max": + self.reduce_fn = torch.max + else: raise ValueError("Mode should be one of 'mean' or 'max' but was" "{0}.".format(mode)) - self.mode = mode + self.from_lua = _is_lua_trained_model(model) - def forward(self, input): - reverse_input = None + def _forward_input_with_reversed_sequence(self, input): + multi_inputs = isinstance(input, dict) + sequence = input if not multi_inputs else input["sequence_batch"] + reversed_sequence = None if self.from_lua: - reverse_input = _flip( - _flip(torch.squeeze(input, 2), 1), 2).unsqueeze_(2) + reversed_sequence = _flip( + _flip(torch.squeeze(sequence, 2), 1), 2).unsqueeze_(2) else: - reverse_input = _flip(_flip(input, 1), 2) + reversed_sequence = _flip(_flip(sequence, 1), 2) - output = self.model.forward(input) - output_from_rev = self.model.forward(reverse_input) - - if self.mode == "mean": - return (output + output_from_rev) / 2 + input_rev = None + if multi_inputs: + input_rev = input.copy() + input_rev["sequence_batch"] = reversed_sequence else: - return torch.max(output, output_from_rev) + input_rev = reversed_sequence + + return self.model.forward(input_rev) + + def forward(self, input): + """Computes NN output for the given sequence and for a reversed sequence, + applies `self.reduce_fn` function to those outputs, and returns the result. + + Parameters + ---------- + input : numpy.ndarray or dict(str, numpy.ndarray) + Model's inputs. Can be just a sequence or multi-inputs. + + """ + output = self.model.forward(input) + output_from_rev = self._forward_input_with_reversed_sequence(input) + return self.reduce_fn(output, output_from_rev)