Skip to content

Commit

Permalink
Utilize SamplesBatch class
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Lapin committed Mar 3, 2021
1 parent 8abac4f commit 32d0c89
Show file tree
Hide file tree
Showing 11 changed files with 167 additions and 147 deletions.
17 changes: 5 additions & 12 deletions selene_sdk/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,24 +209,17 @@ def evaluate(self):
"""
batch_losses = []
all_predictions = []
for (inputs, targets) in self._test_data:
inputs = torch.Tensor(inputs)
targets = torch.Tensor(targets[:, self._use_ixs])
for samples_batch in self._test_data:
inputs, targets = samples_batch.torch_inputs_and_targets(self.use_cuda)
targets = targets[:, self._use_ixs]

if self.use_cuda:
inputs = inputs.cuda()
targets = targets.cuda()
with torch.no_grad():
inputs = Variable(inputs)
targets = Variable(targets)

predictions = None
if _is_lua_trained_model(self.model):
predictions = self.model.forward(
inputs.transpose(1, 2).contiguous().unsqueeze_(2))
inputs.contiguous().unsqueeze_(2))
else:
predictions = self.model.forward(
inputs.transpose(1, 2))
predictions = self.model.forward(inputs)
predictions = predictions[:, self._use_ixs]
loss = self.criterion(predictions, targets)

Expand Down
44 changes: 24 additions & 20 deletions selene_sdk/samplers/file_samplers/bed_file_sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
This module provides the BedFileSampler class.
"""
from selene_sdk.samplers.samples_batch import SamplesBatch
import numpy as np

from .file_sampler import FileSampler
Expand Down Expand Up @@ -96,8 +97,8 @@ def sample(self, batch_size=1):
Returns
-------
sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
A tuple containing the numeric representation of the
SamplesBatch
A batch containing the numeric representation of the
sequence examples and their corresponding labels. The
shape of `sequences` will be
:math:`B \\times L \\times N`, where :math:`B` is
Expand Down Expand Up @@ -163,8 +164,8 @@ def sample(self, batch_size=1):
sequences = np.array(sequences)
if self.targets_avail:
targets = np.array(targets)
return (sequences, targets)
return sequences,
return SamplesBatch(sequences, target_batch=targets)
return SamplesBatch(sequences)

def get_data(self, batch_size, n_samples=None):
"""
Expand All @@ -188,18 +189,21 @@ def get_data(self, batch_size, n_samples=None):
and :math:`N` is the size of the sequence type's alphabet.
"""
# TODO: Should this method return a collection of samples_batch.inputs()?

if not n_samples:
n_samples = self.n_samples
sequences = []

count = batch_size
while count < n_samples:
seqs, = self.sample(batch_size=batch_size)
sequences.append(seqs)
samples_batch = self.sample(batch_size=batch_size)
sequences.append(samples_batch.sequence_batch())
count += batch_size
remainder = batch_size - (count - n_samples)
seqs, = self.sample(batch_size=remainder)
sequences.append(seqs)
samples_batch = self.sample(batch_size=remainder)
sequences.append(samples_batch.sequence_batch())

return sequences

def get_data_and_targets(self, batch_size, n_samples=None):
Expand All @@ -216,11 +220,11 @@ def get_data_and_targets(self, batch_size, n_samples=None):
Returns
-------
sequences_and_targets, targets_matrix : \
tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray)
Tuple containing the list of sequence-target pairs, as well
batches, targets_matrix : \
tuple(list(SamplesBatch), numpy.ndarray)
Tuple containing the list of batches, as well
as a single matrix with all targets in the same order.
Note that `sequences_and_targets`'s sequence elements are of
Note that `batches`'s sequence elements are of
the shape :math:`B \\times L \\times N` and its target
elements are of the shape :math:`B \\times F`, where
:math:`B` is `batch_size`, :math:`L` is the sequence length,
Expand All @@ -236,18 +240,18 @@ def get_data_and_targets(self, batch_size, n_samples=None):
"Please use `get_data` instead.")
if not n_samples:
n_samples = self.n_samples
sequences_and_targets = []
batches = []
targets_mat = []

count = batch_size
while count < n_samples:
seqs, tgts = self.sample(batch_size=batch_size)
sequences_and_targets.append((seqs, tgts))
targets_mat.append(tgts)
samples_batch = self.sample(batch_size=batch_size)
batches.append(samples_batch)
targets_mat.append(samples_batch.targets())
count += batch_size
remainder = batch_size - (count - n_samples)
seqs, tgts = self.sample(batch_size=remainder)
sequences_and_targets.append((seqs, tgts))
targets_mat.append(tgts)
samples_batch = self.sample(batch_size=remainder)
batches.append(samples_batch)
targets_mat.append(samples_batch.targets())
targets_mat = np.vstack(targets_mat).astype(int)
return sequences_and_targets, targets_mat
return batches, targets_mat
4 changes: 3 additions & 1 deletion selene_sdk/samplers/file_samplers/file_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from abc import ABCMeta
from abc import abstractmethod

from selene_sdk.samplers.samples_batch import SamplesBatch


class FileSampler(metaclass=ABCMeta):
"""
Expand All @@ -26,7 +28,7 @@ def __init__(self):
"""

@abstractmethod
def sample(self, batch_size=1):
def sample(self, batch_size=1) -> SamplesBatch:
"""
Fetches a mini-batch of the data from the sampler.
Expand Down
43 changes: 23 additions & 20 deletions selene_sdk/samplers/file_samplers/mat_file_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import h5py
import numpy as np
import scipy.io
from selene_sdk.samplers.samples_batch import SamplesBatch

from .file_sampler import FileSampler

Expand Down Expand Up @@ -126,8 +127,8 @@ def sample(self, batch_size=1):
Returns
-------
sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
A tuple containing the numeric representation of the
SamplesBatch
A batch containing the numeric representation of the
sequence examples and their corresponding labels. The
shape of `sequences` will be
:math:`B \\times L \\times N`, where :math:`B` is
Expand Down Expand Up @@ -166,8 +167,8 @@ def sample(self, batch_size=1):
targets = self._sample_tgts[:, use_indices].astype(float)
targets = np.transpose(
targets, (1, 0))
return (sequences, targets)
return sequences,
return SamplesBatch(sequences, target_batch=targets)
return SamplesBatch(sequences)

def get_data(self, batch_size, n_samples=None):
"""
Expand All @@ -190,18 +191,20 @@ def get_data(self, batch_size, n_samples=None):
is `batch_size`, :math:`L` is the sequence length,
and :math:`N` is the size of the sequence type's alphabet.
"""
# TODO: Should this method return a collection of samples_batch.inputs()?

if not n_samples:
n_samples = self.n_samples
sequences = []

count = batch_size
while count < n_samples:
seqs, = self.sample(batch_size=batch_size)
sequences.append(seqs)
samples_batch = self.sample(batch_size=batch_size)
sequences.append(samples_batch.sequence_batch())
count += batch_size
remainder = batch_size - (count - n_samples)
seqs, = self.sample(batch_size=remainder)
sequences.append(seqs)
samples_batch = self.sample(batch_size=remainder)
sequences.append(samples_batch.sequence_batch())
return sequences

def get_data_and_targets(self, batch_size, n_samples=None):
Expand All @@ -218,11 +221,11 @@ def get_data_and_targets(self, batch_size, n_samples=None):
Returns
-------
sequences_and_targets, targets_matrix : \
tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray)
Tuple containing the list of sequence-target pairs, as well
batches, targets_matrix : \
tuple(list(SamplesBatch), numpy.ndarray)
Tuple containing the list of batches, as well
as a single matrix with all targets in the same order.
Note that `sequences_and_targets`'s sequence elements are of
Note that `batches`'s sequence elements are of
the shape :math:`B \\times L \\times N` and its target
elements are of the shape :math:`B \\times F`, where
:math:`B` is `batch_size`, :math:`L` is the sequence length,
Expand All @@ -237,19 +240,19 @@ def get_data_and_targets(self, batch_size, n_samples=None):
"initialization. Please use `get_data` instead.")
if not n_samples:
n_samples = self.n_samples
sequences_and_targets = []
batches = []
targets_mat = []

count = batch_size
while count < n_samples:
seqs, tgts = self.sample(batch_size=batch_size)
sequences_and_targets.append((seqs, tgts))
targets_mat.append(tgts)
samples_batch = self.sample(batch_size=batch_size)
batches.append(samples_batch)
targets_mat.append(samples_batch.targets())
count += batch_size
remainder = batch_size - (count - n_samples)
seqs, tgts = self.sample(batch_size=remainder)
sequences_and_targets.append((seqs, tgts))
targets_mat.append(tgts)
samples_batch = self.sample(batch_size=remainder)
batches.append(samples_batch)
targets_mat.append(samples_batch.targets())
# TODO: should not assume targets are always integers
targets_mat = np.vstack(targets_mat).astype(float)
return sequences_and_targets, targets_mat
return batches, targets_mat
11 changes: 6 additions & 5 deletions selene_sdk/samplers/intervals_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
This module provides the `IntervalsSampler` class and supporting
methods.
"""
from collections import namedtuple
import logging
import random
from collections import namedtuple

import numpy as np

from .online_sampler import OnlineSampler
from selene_sdk.samplers.samples_batch import SamplesBatch
from ..utils import get_indices_and_probabilities
from .online_sampler import OnlineSampler

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -388,8 +389,8 @@ def sample(self, batch_size=1):
Returns
-------
sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
A tuple containing the numeric representation of the
SamplesBatch
A batch containing the numeric representation of the
sequence examples and their corresponding labels. The
shape of `sequences` will be
:math:`B \\times L \\times N`, where :math:`B` is
Expand Down Expand Up @@ -426,4 +427,4 @@ def sample(self, batch_size=1):
sequences[n_samples_drawn, :, :] = seq
targets[n_samples_drawn, :] = seq_targets
n_samples_drawn += 1
return (sequences, targets)
return SamplesBatch(sequences, target_batch=targets)
8 changes: 4 additions & 4 deletions selene_sdk/samplers/multi_file_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ def get_test_set(self, batch_size, n_samples=None):
Returns
-------
sequences_and_targets, targets_matrix : \
tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray)
Tuple containing the list of sequence-target pairs, as well
batches, targets_matrix : \
tuple(list(SamplesBatch), numpy.ndarray)
Tuple containing the list of batches, as well
as a single matrix with all targets in the same order.
Note that `sequences_and_targets`'s sequence elements are of
Note that `batches`'s sequence elements are of
the shape :math:`B \\times L \\times N` and its target
elements are of the shape :math:`B \\times F`, where
:math:`B` is `batch_size`, :math:`L` is the sequence length,
Expand Down
Loading

0 comments on commit 32d0c89

Please sign in to comment.