From 8abac4f53aaf197848177e88493cff89a8c994f0 Mon Sep 17 00:00:00 2001
From: Alex Lapin
Date: Mon, 1 Mar 2021 15:42:22 +0000
Subject: [PATCH] Add SamplesBatch class

---

Review notes (this region is commentary, not part of the diff payload):
  * Whitespace inside docstrings and any trailing whitespace are a
    best-effort reconstruction; code indentation follows Python syntax.
    The hunk sizes (+1,106 and +1,48) match the number of added lines below.
  * SamplesBatch.sequence_batch() is annotated `-> torch.Tensor`, but it
    returns the stored value unchanged, which the tests supply as a
    numpy.ndarray.
  * The "Returns" section of torch_inputs_and_targets() names numpy.ndarray,
    while the body wraps every input and the target in torch.Tensor.
  * `other_input_batches=dict()` is a mutable default; it is copied before
    use, so nothing is shared between instances, but `None` is the usual
    idiom.
  * inputs() hands back the internal dict itself when there are several
    inputs, so callers that mutate it also mutate the batch.

 selene_sdk/predict/tests/__init__.py     |   0
 selene_sdk/samplers/samples_batch.py     | 106 ++++++++++++++++++
 selene_sdk/samplers/tests/__init__.py    |   0
 .../samplers/tests/test_samples_batch.py |  48 ++++++++
 selene_sdk/sequences/tests/__init__.py   |   0
 selene_sdk/targets/tests/__init__.py     |   0
 6 files changed, 154 insertions(+)
 create mode 100644 selene_sdk/predict/tests/__init__.py
 create mode 100644 selene_sdk/samplers/samples_batch.py
 create mode 100644 selene_sdk/samplers/tests/__init__.py
 create mode 100644 selene_sdk/samplers/tests/test_samples_batch.py
 create mode 100644 selene_sdk/sequences/tests/__init__.py
 create mode 100644 selene_sdk/targets/tests/__init__.py

diff --git a/selene_sdk/predict/tests/__init__.py b/selene_sdk/predict/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/selene_sdk/samplers/samples_batch.py b/selene_sdk/samplers/samples_batch.py
new file mode 100644
index 00000000..8bade930
--- /dev/null
+++ b/selene_sdk/samplers/samples_batch.py
@@ -0,0 +1,106 @@
+import numpy as np
+import torch
+
+
+class SamplesBatch:
+    """
+    This class represents NN inputs and targets. Values are stored as numpy.ndarrays
+    and there is a method to convert them to torch.Tensors.
+
+    Inputs are stored in a dict, which can be used if you are providing more than just a
+    `sequence_batch` to the NN.
+
+    NOTE: If you store just a sequence as an input to the model, then `inputs()` and
+    `torch_inputs_and_targets()` will return only the batch of sequences rather than
+    a dict.
+
+    """
+
+    _SEQUENCE_LABEL = "sequence_batch"
+
+    def __init__(
+        self,
+        sequence_batch: np.ndarray,
+        other_input_batches=dict(),
+        target_batch: np.ndarray = None,
+    ) -> None:
+        self._input_batches = other_input_batches.copy()
+        self._input_batches[self._SEQUENCE_LABEL] = sequence_batch
+        self._target_batch = target_batch
+
+    def sequence_batch(self) -> torch.Tensor:
+        """Returns the sequence batch with a shape of
+        [batch_size, sequence_length, alphabet_size].
+        """
+        return self._input_batches[self._SEQUENCE_LABEL]
+
+    def inputs(self):
+        """Based on the size of inputs dictionary, returns either just the
+        sequence or the whole dictionary.
+
+        Returns
+        -------
+        numpy.ndarray or dict[str, numpy.ndarray]
+            numpy.ndarray is returned when inputs contain just the sequence batch.
+            dict[str, numpy.ndarray] is returned when there are multiple inputs.
+
+        NOTE: Sequence batch has a shape of
+        [batch_size, sequence_length, alphabet_size].
+        """
+        if len(self._input_batches) == 1:
+            return self.sequence_batch()
+
+        return self._input_batches
+
+    def targets(self):
+        """Returns target batch if it is present.
+
+        Returns
+        -------
+        numpy.ndarray
+
+        """
+        return self._target_batch
+
+    def torch_inputs_and_targets(self, use_cuda: bool):
+        """
+        Returns inputs and targets in torch.Tensor format.
+
+        Based on the size of inputs dictionary, returns either just the
+        sequence or the whole dictionary.
+
+        Returns
+        -------
+        inputs, targets :\
+                tuple(numpy.ndarray or dict[str, numpy.ndarray], numpy.ndarray)
+            For `inputs`:
+                numpy.ndarray is returned when inputs contain just the sequence batch.
+                dict[str, numpy.ndarray] is returned when there are multiple inputs.
+
+            NOTE: Returned sequence batch has a shape of
+            [batch_size, alphabet_size, sequence_length].
+
+        """
+        all_inputs = dict()
+        for key, value in self._input_batches.items():
+            all_inputs[key] = torch.Tensor(value)
+
+            if use_cuda:
+                all_inputs[key] = all_inputs[key].cuda()
+
+        # Transpose the sequences to satisfy NN convolution input format (which is
+        # [batch_size, channels_size, sequence_length]).
+        all_inputs[self._SEQUENCE_LABEL] = all_inputs[self._SEQUENCE_LABEL].transpose(
+            1, 2
+        )
+
+        inputs = all_inputs if len(all_inputs) > 1 else all_inputs[self._SEQUENCE_LABEL]
+
+        targets = None
+        if self._target_batch is not None:
+            targets = torch.Tensor(self._target_batch)
+
+            if use_cuda:
+                targets = targets.cuda()
+
+        return inputs, targets
diff --git a/selene_sdk/samplers/tests/__init__.py b/selene_sdk/samplers/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/selene_sdk/samplers/tests/test_samples_batch.py b/selene_sdk/samplers/tests/test_samples_batch.py
new file mode 100644
index 00000000..c2f4c055
--- /dev/null
+++ b/selene_sdk/samplers/tests/test_samples_batch.py
@@ -0,0 +1,48 @@
+import unittest
+
+import numpy as np
+import torch
+from selene_sdk.samplers.samples_batch import SamplesBatch
+
+
+class TestSamplesBatch(unittest.TestCase):
+    def test_single_input(self):
+        samples_batch = SamplesBatch(np.ones((6, 200, 4)), target_batch=np.ones(20))
+
+        inputs = samples_batch.inputs()
+        self.assertIsInstance(inputs, np.ndarray)
+        self.assertSequenceEqual(inputs.shape, (6, 200, 4))
+
+        torch_inputs, _ = samples_batch.torch_inputs_and_targets(use_cuda=False)
+        self.assertIsInstance(torch_inputs, torch.Tensor)
+        self.assertSequenceEqual(torch_inputs.shape, (6, 4, 200))
+
+    def test_multiple_inputs(self):
+        samples_batch = SamplesBatch(
+            np.ones((6, 200, 4)),
+            other_input_batches={"something": np.ones(10)},
+            target_batch=np.ones(20),
+        )
+
+        inputs = samples_batch.inputs()
+        self.assertIsInstance(inputs, dict)
+        self.assertEqual(len(inputs), 2)
+        self.assertSequenceEqual(inputs["sequence_batch"].shape, (6, 200, 4))
+
+        torch_inputs, _ = samples_batch.torch_inputs_and_targets(use_cuda=False)
+        self.assertIsInstance(torch_inputs, dict)
+        self.assertEqual(len(torch_inputs), 2)
+        self.assertSequenceEqual(torch_inputs["sequence_batch"].shape, (6, 4, 200))
+
+    def test_has_target(self):
+        samples_batch = SamplesBatch(np.ones((6, 200, 4)), target_batch=np.ones(20))
+        targets = samples_batch.targets()
+        self.assertIsInstance(targets, np.ndarray)
+        _, torch_targets = samples_batch.torch_inputs_and_targets(use_cuda=False)
+        self.assertIsInstance(torch_targets, torch.Tensor)
+
+    def test_no_target(self):
+        samples_batch = SamplesBatch(np.ones((6, 200, 4)))
+        self.assertIsNone(samples_batch.targets())
+        _, torch_targets = samples_batch.torch_inputs_and_targets(use_cuda=False)
+        self.assertIsNone(torch_targets)
diff --git a/selene_sdk/sequences/tests/__init__.py b/selene_sdk/sequences/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/selene_sdk/targets/tests/__init__.py b/selene_sdk/targets/tests/__init__.py
new file mode 100644
index 00000000..e69de29b