-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Alex Lapin
committed
Mar 2, 2021
1 parent
9bfa1fd
commit 8abac4f
Showing
6 changed files
with
154 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import numpy as np | ||
import torch | ||
|
||
|
||
class SamplesBatch:
    """
    Container for one batch of NN inputs and (optionally) targets.

    Values are stored as numpy.ndarrays; `torch_inputs_and_targets()` converts
    them to torch.Tensors. Inputs are kept in a dict, which makes it possible
    to feed the NN with more than just a `sequence_batch`.

    NOTE: If only a sequence is stored as the model input, then `inputs()` and
    `torch_inputs_and_targets()` return just the batch of sequences rather than
    a dict.
    """

    # Key under which the sequence batch is stored in the inputs dict.
    _SEQUENCE_LABEL = "sequence_batch"

    def __init__(
        self,
        sequence_batch: np.ndarray,
        other_input_batches=None,
        target_batch: np.ndarray = None,
    ) -> None:
        """
        Parameters
        ----------
        sequence_batch : numpy.ndarray
            Batch of sequences, stored under the `"sequence_batch"` key.
        other_input_batches : dict[str, numpy.ndarray] or None, optional
            Additional named inputs. `None` (the default) means "no extra
            inputs". The mapping is shallow-copied, so the caller's object is
            never modified. An entry named `"sequence_batch"` is overwritten
            by `sequence_batch`.
        target_batch : numpy.ndarray or None, optional
            Batch of targets, if any.
        """
        # A `None` default replaces the former `dict()` default: a mutable
        # default is shared between calls, and an explicit `None` used to fail
        # on `.copy()`.
        if other_input_batches is None:
            self._input_batches = {}
        else:
            self._input_batches = dict(other_input_batches)
        self._input_batches[self._SEQUENCE_LABEL] = sequence_batch
        self._target_batch = target_batch

    @staticmethod
    def _to_tensor(array, use_cuda: bool) -> torch.Tensor:
        """Convert `array` to a torch.Tensor, moving it to CUDA if requested."""
        tensor = torch.Tensor(array)
        return tensor.cuda() if use_cuda else tensor

    def sequence_batch(self) -> np.ndarray:
        """Returns the sequence batch exactly as it was stored, with a shape of
        [batch_size, sequence_length, alphabet_size].
        """
        return self._input_batches[self._SEQUENCE_LABEL]

    def inputs(self):
        """Based on the size of inputs dictionary, returns either just the
        sequence or the whole dictionary.

        Returns
        -------
        numpy.ndarray or dict[str, numpy.ndarray]
            numpy.ndarray is returned when inputs contain just the sequence
            batch. dict[str, numpy.ndarray] is returned when there are
            multiple inputs; this is the internal dict itself, not a copy.

        NOTE: Sequence batch has a shape of
        [batch_size, sequence_length, alphabet_size].
        """
        if len(self._input_batches) == 1:
            return self.sequence_batch()

        return self._input_batches

    def targets(self):
        """Returns target batch if it is present.

        Returns
        -------
        numpy.ndarray or None
            `None` when no target batch was provided.
        """
        return self._target_batch

    def torch_inputs_and_targets(self, use_cuda: bool):
        """
        Returns inputs and targets in torch.Tensor format.

        Based on the size of inputs dictionary, returns either just the
        sequence or the whole dictionary.

        Parameters
        ----------
        use_cuda : bool
            When True, every returned tensor is moved to the GPU via `.cuda()`.

        Returns
        -------
        inputs, targets :\
        tuple(torch.Tensor or dict[str, torch.Tensor], torch.Tensor or None)
            For `inputs`:
            torch.Tensor is returned when inputs contain just the sequence
            batch. dict[str, torch.Tensor] is returned when there are
            multiple inputs.
            `targets` is None when no target batch was provided.

        NOTE: Returned sequence batch has a shape of
        [batch_size, alphabet_size, sequence_length].
        """
        all_inputs = {
            key: self._to_tensor(value, use_cuda)
            for key, value in self._input_batches.items()
        }

        # Transpose the sequences to satisfy NN convolution input format (which
        # is [batch_size, channels_size, sequence_length]).
        all_inputs[self._SEQUENCE_LABEL] = all_inputs[
            self._SEQUENCE_LABEL
        ].transpose(1, 2)

        if len(all_inputs) > 1:
            inputs = all_inputs
        else:
            inputs = all_inputs[self._SEQUENCE_LABEL]

        targets = None
        if self._target_batch is not None:
            targets = self._to_tensor(self._target_batch, use_cuda)

        return inputs, targets
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
import torch | ||
from selene_sdk.samplers.samples_batch import SamplesBatch | ||
|
||
|
||
class TestSamplesBatch(unittest.TestCase):
    """Checks the numpy and torch views exposed by `SamplesBatch`."""

    def test_single_input(self):
        """With only a sequence stored, inputs come back as a bare array/tensor."""
        batch = SamplesBatch(np.ones((6, 200, 4)), target_batch=np.ones(20))

        numpy_inputs = batch.inputs()
        self.assertIsInstance(numpy_inputs, np.ndarray)
        self.assertSequenceEqual(numpy_inputs.shape, (6, 200, 4))

        # The torch view has the last two axes of the sequence batch swapped.
        tensor_inputs, _ = batch.torch_inputs_and_targets(use_cuda=False)
        self.assertIsInstance(tensor_inputs, torch.Tensor)
        self.assertSequenceEqual(tensor_inputs.shape, (6, 4, 200))

    def test_multiple_inputs(self):
        """With an extra input stored, inputs come back as a two-entry dict."""
        batch = SamplesBatch(
            np.ones((6, 200, 4)),
            other_input_batches={"something": np.ones(10)},
            target_batch=np.ones(20),
        )

        numpy_inputs = batch.inputs()
        self.assertIsInstance(numpy_inputs, dict)
        self.assertEqual(len(numpy_inputs), 2)
        self.assertSequenceEqual(
            numpy_inputs["sequence_batch"].shape, (6, 200, 4)
        )

        tensor_inputs, _ = batch.torch_inputs_and_targets(use_cuda=False)
        self.assertIsInstance(tensor_inputs, dict)
        self.assertEqual(len(tensor_inputs), 2)
        self.assertSequenceEqual(
            tensor_inputs["sequence_batch"].shape, (6, 4, 200)
        )

    def test_has_target(self):
        """A provided target is exposed as an ndarray and as a torch.Tensor."""
        batch = SamplesBatch(np.ones((6, 200, 4)), target_batch=np.ones(20))

        numpy_targets = batch.targets()
        self.assertIsInstance(numpy_targets, np.ndarray)

        _, tensor_targets = batch.torch_inputs_and_targets(use_cuda=False)
        self.assertIsInstance(tensor_targets, torch.Tensor)

    def test_no_target(self):
        """Without a target, both the numpy and torch views report None."""
        batch = SamplesBatch(np.ones((6, 200, 4)))
        self.assertIsNone(batch.targets())

        _, tensor_targets = batch.torch_inputs_and_targets(use_cuda=False)
        self.assertIsNone(tensor_targets)
Empty file.
Empty file.