Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi-inputs #163

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions selene_sdk/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,24 +209,17 @@ def evaluate(self):
"""
batch_losses = []
all_predictions = []
for (inputs, targets) in self._test_data:
inputs = torch.Tensor(inputs)
targets = torch.Tensor(targets[:, self._use_ixs])
for samples_batch in self._test_data:
inputs, targets = samples_batch.torch_inputs_and_targets(self.use_cuda)
targets = targets[:, self._use_ixs]

if self.use_cuda:
inputs = inputs.cuda()
targets = targets.cuda()
with torch.no_grad():
inputs = Variable(inputs)
targets = Variable(targets)

predictions = None
if _is_lua_trained_model(self.model):
predictions = self.model.forward(
inputs.transpose(1, 2).contiguous().unsqueeze_(2))
inputs.contiguous().unsqueeze_(2))
else:
predictions = self.model.forward(
inputs.transpose(1, 2))
predictions = self.model.forward(inputs)
predictions = predictions[:, self._use_ixs]
loss = self.criterion(predictions, targets)

Expand Down
Empty file.
44 changes: 24 additions & 20 deletions selene_sdk/samplers/file_samplers/bed_file_sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
This module provides the BedFileSampler class.
"""
from selene_sdk.samplers.samples_batch import SamplesBatch
import numpy as np

from .file_sampler import FileSampler
Expand Down Expand Up @@ -96,8 +97,8 @@ def sample(self, batch_size=1):

Returns
-------
sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
A tuple containing the numeric representation of the
SamplesBatch
A batch containing the numeric representation of the
sequence examples and their corresponding labels. The
shape of `sequences` will be
:math:`B \\times L \\times N`, where :math:`B` is
Expand Down Expand Up @@ -163,8 +164,8 @@ def sample(self, batch_size=1):
sequences = np.array(sequences)
if self.targets_avail:
targets = np.array(targets)
return (sequences, targets)
return sequences,
return SamplesBatch(sequences, target_batch=targets)
return SamplesBatch(sequences)

def get_data(self, batch_size, n_samples=None):
"""
Expand All @@ -188,18 +189,21 @@ def get_data(self, batch_size, n_samples=None):
and :math:`N` is the size of the sequence type's alphabet.

"""
# TODO: Should this method return a collection of samples_batch.inputs()?

if not n_samples:
n_samples = self.n_samples
sequences = []

count = batch_size
while count < n_samples:
seqs, = self.sample(batch_size=batch_size)
sequences.append(seqs)
samples_batch = self.sample(batch_size=batch_size)
sequences.append(samples_batch.sequence_batch())
count += batch_size
remainder = batch_size - (count - n_samples)
seqs, = self.sample(batch_size=remainder)
sequences.append(seqs)
samples_batch = self.sample(batch_size=remainder)
sequences.append(samples_batch.sequence_batch())

return sequences

def get_data_and_targets(self, batch_size, n_samples=None):
Expand All @@ -216,11 +220,11 @@ def get_data_and_targets(self, batch_size, n_samples=None):

Returns
-------
sequences_and_targets, targets_matrix : \
tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray)
Tuple containing the list of sequence-target pairs, as well
batches, targets_matrix : \
tuple(list(SamplesBatch), numpy.ndarray)
Tuple containing the list of batches, as well
as a single matrix with all targets in the same order.
Note that `sequences_and_targets`'s sequence elements are of
Note that `batches`'s sequence elements are of
the shape :math:`B \\times L \\times N` and its target
elements are of the shape :math:`B \\times F`, where
:math:`B` is `batch_size`, :math:`L` is the sequence length,
Expand All @@ -236,18 +240,18 @@ def get_data_and_targets(self, batch_size, n_samples=None):
"Please use `get_data` instead.")
if not n_samples:
n_samples = self.n_samples
sequences_and_targets = []
batches = []
targets_mat = []

count = batch_size
while count < n_samples:
seqs, tgts = self.sample(batch_size=batch_size)
sequences_and_targets.append((seqs, tgts))
targets_mat.append(tgts)
samples_batch = self.sample(batch_size=batch_size)
batches.append(samples_batch)
targets_mat.append(samples_batch.targets())
count += batch_size
remainder = batch_size - (count - n_samples)
seqs, tgts = self.sample(batch_size=remainder)
sequences_and_targets.append((seqs, tgts))
targets_mat.append(tgts)
samples_batch = self.sample(batch_size=remainder)
batches.append(samples_batch)
targets_mat.append(samples_batch.targets())
targets_mat = np.vstack(targets_mat).astype(int)
return sequences_and_targets, targets_mat
return batches, targets_mat
4 changes: 3 additions & 1 deletion selene_sdk/samplers/file_samplers/file_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from abc import ABCMeta
from abc import abstractmethod

from selene_sdk.samplers.samples_batch import SamplesBatch


class FileSampler(metaclass=ABCMeta):
"""
Expand All @@ -26,7 +28,7 @@ def __init__(self):
"""

@abstractmethod
def sample(self, batch_size=1):
def sample(self, batch_size=1) -> SamplesBatch:
"""
Fetches a mini-batch of the data from the sampler.

Expand Down
43 changes: 23 additions & 20 deletions selene_sdk/samplers/file_samplers/mat_file_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import h5py
import numpy as np
import scipy.io
from selene_sdk.samplers.samples_batch import SamplesBatch

from .file_sampler import FileSampler

Expand Down Expand Up @@ -126,8 +127,8 @@ def sample(self, batch_size=1):

Returns
-------
sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
A tuple containing the numeric representation of the
SamplesBatch
A batch containing the numeric representation of the
sequence examples and their corresponding labels. The
shape of `sequences` will be
:math:`B \\times L \\times N`, where :math:`B` is
Expand Down Expand Up @@ -166,8 +167,8 @@ def sample(self, batch_size=1):
targets = self._sample_tgts[:, use_indices].astype(float)
targets = np.transpose(
targets, (1, 0))
return (sequences, targets)
return sequences,
return SamplesBatch(sequences, target_batch=targets)
return SamplesBatch(sequences)

def get_data(self, batch_size, n_samples=None):
"""
Expand All @@ -190,18 +191,20 @@ def get_data(self, batch_size, n_samples=None):
is `batch_size`, :math:`L` is the sequence length,
and :math:`N` is the size of the sequence type's alphabet.
"""
# TODO: Should this method return a collection of samples_batch.inputs()?

if not n_samples:
n_samples = self.n_samples
sequences = []

count = batch_size
while count < n_samples:
seqs, = self.sample(batch_size=batch_size)
sequences.append(seqs)
samples_batch = self.sample(batch_size=batch_size)
sequences.append(samples_batch.sequence_batch())
count += batch_size
remainder = batch_size - (count - n_samples)
seqs, = self.sample(batch_size=remainder)
sequences.append(seqs)
samples_batch = self.sample(batch_size=remainder)
sequences.append(samples_batch.sequence_batch())
return sequences

def get_data_and_targets(self, batch_size, n_samples=None):
Expand All @@ -218,11 +221,11 @@ def get_data_and_targets(self, batch_size, n_samples=None):

Returns
-------
sequences_and_targets, targets_matrix : \
tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray)
Tuple containing the list of sequence-target pairs, as well
batches, targets_matrix : \
tuple(list(SamplesBatch), numpy.ndarray)
Tuple containing the list of batches, as well
as a single matrix with all targets in the same order.
Note that `sequences_and_targets`'s sequence elements are of
Note that `batches`'s sequence elements are of
the shape :math:`B \\times L \\times N` and its target
elements are of the shape :math:`B \\times F`, where
:math:`B` is `batch_size`, :math:`L` is the sequence length,
Expand All @@ -237,19 +240,19 @@ def get_data_and_targets(self, batch_size, n_samples=None):
"initialization. Please use `get_data` instead.")
if not n_samples:
n_samples = self.n_samples
sequences_and_targets = []
batches = []
targets_mat = []

count = batch_size
while count < n_samples:
seqs, tgts = self.sample(batch_size=batch_size)
sequences_and_targets.append((seqs, tgts))
targets_mat.append(tgts)
samples_batch = self.sample(batch_size=batch_size)
batches.append(samples_batch)
targets_mat.append(samples_batch.targets())
count += batch_size
remainder = batch_size - (count - n_samples)
seqs, tgts = self.sample(batch_size=remainder)
sequences_and_targets.append((seqs, tgts))
targets_mat.append(tgts)
samples_batch = self.sample(batch_size=remainder)
batches.append(samples_batch)
targets_mat.append(samples_batch.targets())
# TODO: should not assume targets are always integers
targets_mat = np.vstack(targets_mat).astype(float)
return sequences_and_targets, targets_mat
return batches, targets_mat
11 changes: 6 additions & 5 deletions selene_sdk/samplers/intervals_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
This module provides the `IntervalsSampler` class and supporting
methods.
"""
from collections import namedtuple
import logging
import random
from collections import namedtuple

import numpy as np

from .online_sampler import OnlineSampler
from selene_sdk.samplers.samples_batch import SamplesBatch
from ..utils import get_indices_and_probabilities
from .online_sampler import OnlineSampler

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -388,8 +389,8 @@ def sample(self, batch_size=1):

Returns
-------
sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
A tuple containing the numeric representation of the
SamplesBatch
A batch containing the numeric representation of the
sequence examples and their corresponding labels. The
shape of `sequences` will be
:math:`B \\times L \\times N`, where :math:`B` is
Expand Down Expand Up @@ -426,4 +427,4 @@ def sample(self, batch_size=1):
sequences[n_samples_drawn, :, :] = seq
targets[n_samples_drawn, :] = seq_targets
n_samples_drawn += 1
return (sequences, targets)
return SamplesBatch(sequences, target_batch=targets)
8 changes: 4 additions & 4 deletions selene_sdk/samplers/multi_file_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ def get_test_set(self, batch_size, n_samples=None):

Returns
-------
sequences_and_targets, targets_matrix : \
tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray)
Tuple containing the list of sequence-target pairs, as well
batches, targets_matrix : \
tuple(list(SamplesBatch), numpy.ndarray)
Tuple containing the list of batches, as well
as a single matrix with all targets in the same order.
Note that `sequences_and_targets`'s sequence elements are of
Note that `batches`'s sequence elements are of
the shape :math:`B \\times L \\times N` and its target
elements are of the shape :math:`B \\times F`, where
:math:`B` is `batch_size`, :math:`L` is the sequence length,
Expand Down
Loading