118 changes: 118 additions & 0 deletions dwave/plugins/torch/nn.py
@@ -0,0 +1,118 @@
import inspect
from functools import wraps
from types import MappingProxyType

import torch
from torch import nn


def store_config(fn):
@wraps(fn)
def wrapper(self, *args, **kwargs):
# Get signature of function and match the arguments with their names
sig = inspect.signature(fn)
bound = sig.bind(self, *args, **kwargs)
# Use default values if the args/kwargs were not supplied
bound.apply_defaults()
        config = {k: v for k, v in bound.arguments.items() if k != 'self'}
        # Record the module's class name alongside its constructor arguments
        config['module_name'] = self.__class__.__name__
        # Expose the configuration as a read-only mapping
        self.config = MappingProxyType(config)
fn(self, *args, **kwargs)
return wrapper
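
# A minimal sketch of the intended use of ``store_config`` (the class name
# ``Affine`` below is hypothetical): decorating ``__init__`` exposes a
# read-only ``config`` mapping of the bound arguments plus the module name.
#
#     class Affine(nn.Module):
#         @store_config
#         def __init__(self, din, dout, bias=True):
#             super().__init__()
#             self.linear = nn.Linear(din, dout, bias=bias)
#
#     Affine(3, 4).config
#     # mappingproxy({'din': 3, 'dout': 4, 'bias': True, 'module_name': 'Affine'})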


class Identity(nn.Module):

@store_config
def __init__(self):
"""An identity module.

This module is useful for handling cases where a neural network module is expected, but no
effect is desired."""
super().__init__()

def forward(self, x) -> torch.Tensor:
"""Input

Args:
x (torch.Tensor): The input and the output.

Returns:
torch.Tensor: The input and the output.
"""
return x
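
# Usage sketch: ``Identity()(x)`` returns ``x`` itself, making the module a
# drop-in no-op wherever an ``nn.Module`` is required, e.g. as the skip path
# in ``SkipLinear`` below.
#
#     assert Identity()(torch.ones(3)).equal(torch.ones(3))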


class SkipLinear(nn.Module):

@store_config
def __init__(self, din, dout) -> None:
"""Applies a linear transformation to the incoming data: :math:`y = xA^T`.

This module is identity when `din == dout`, otherwise, it is a linear transformation, i.e.,
no bias term.

Args:
din (int): Size of each input sample.
dout (int): Size of each output sample.
"""
super().__init__()
if din == dout:
self.linear = Identity()
else:
self.linear = nn.Linear(din, dout, bias=False)

def forward(self, x) -> torch.Tensor:
"""Apply a linear transformation to the input variable `x`.

Args:
x (torch.Tensor): the input tensor.

Returns:
torch.Tensor: the linearly-transformed tensor of `x`.
"""
return self.linear(x)
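
# A short usage sketch (the sizes are illustrative): when the sizes match,
# the module is a true identity; otherwise it is a bias-free projection.
#
#     x = torch.randn(8)
#     assert SkipLinear(8, 8)(x).equal(x)          # identity path
#     assert SkipLinear(8, 16)(x).shape == (16,)   # projection to size 16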


class LinearBlock(nn.Module):
@store_config
def __init__(self, din, dout, p) -> None:
"""A linear block consisting of normalizations, linear transformations, dropout, relu, and a skip connection.

The module is composed of (in order):
1. a first layer norm,
2. a first linear transformation,
3. a dropout,
4. a relu activation,
5. a second layer norm,
6. a second linear layer, and, finally,
7. a skip connection from initial input to output.

Args:
din (int): Size of each input sample
dout (int): Size of each output sample
p (float): Dropout probability.
"""
super().__init__()
self.skip = SkipLinear(din, dout)
        self.block = nn.Sequential(
            nn.LayerNorm(din),
            nn.Linear(din, dout),
            nn.Dropout(p),
            nn.ReLU(),
            nn.LayerNorm(dout),
            nn.Linear(dout, dout),
        )

def forward(self, x) -> torch.Tensor:
"""Transforms the input `x` with the modules.

Args:
x (torch.Tensor): An input tensor.

Returns:
torch.Tensor: Output tensor.
"""
return self.block(x) + self.skip(x)
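
# A minimal usage sketch for ``LinearBlock`` (the batch size 32 is
# illustrative): the residual sum is valid for any ``din``/``dout`` because
# ``SkipLinear`` projects the input whenever the sizes differ.
#
#     block = LinearBlock(din=64, dout=128, p=0.1)
#     y = block(torch.randn(32, 64))
#     assert y.shape == (32, 128)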
5 changes: 5 additions & 0 deletions releasenotes/notes/add-nn-modules-c29a092140eacbe1.yaml
@@ -0,0 +1,5 @@
---
features:
- Add the Python module ``dwave.plugins.torch.nn``
- Add ``LinearBlock`` and ``SkipLinear`` modules
- Add utilities for testing modules
52 changes: 52 additions & 0 deletions tests/test_nn.py
@@ -0,0 +1,52 @@
import unittest

import torch
from parameterized import parameterized

from dwave.plugins.torch.nn import LinearBlock, SkipLinear, store_config
from tests.utils import model_probably_good


class TestNN(unittest.TestCase):
"""The tests in this class is, generally, concerned with two characteristics of the output.
1. Module outputs, probably, do not end with an activation function, and
2. the output tensor shapes are as expected.
"""

def test_store_config(self):
# Check the Module stores configs as expected.
class MyModel(torch.nn.Module):
@store_config
def __init__(self, a, b=1, *, x=4, y='hello'):
super().__init__()

def forward(self, x):
return x
model = MyModel(a=123, x=5)
self.assertDictEqual(dict(model.config),
{"a": 123, "b": 1, "x": 5, "y": "hello",
"module_name": "MyModel"})

@parameterized.expand([0, 0.5, 1])
def test_LinearBlock(self, p):
din = 32
dout = 177
model = LinearBlock(din, dout, p)
self.assertTrue(model_probably_good(model, (din,), (dout,)))

def test_SkipLinear(self):
din = 33
dout = 99
model = SkipLinear(din, dout)
        self.assertTrue(model_probably_good(model, (din,), (dout,)))
with self.subTest("Check identity for `din == dout`"):
dim = 123
model = SkipLinear(dim, dim)
x = torch.randn((dim,))
y = model(x)
self.assertTrue((x == y).all())
            self.assertTrue(model_probably_good(model, (dim,), (dim,)))


if __name__ == "__main__":
unittest.main()
143 changes: 143 additions & 0 deletions tests/test_test_utils.py
@@ -0,0 +1,143 @@
import unittest

import torch
from parameterized import parameterized

from dwave.plugins.torch.nn import store_config
from tests import utils


class TestTestUtils(unittest.TestCase):

def test_probably_unconstrained(self):
x = torch.randn((1000, 10, 10))
self.assertTrue(utils._probably_unconstrained(x))

        # Common activations bound the output range and should be flagged
self.assertFalse(utils._probably_unconstrained(x.sigmoid()))
self.assertFalse(utils._probably_unconstrained(x.relu()))
self.assertFalse(utils._probably_unconstrained(x.tanh()))

def test__are_all_spins(self):
        # Single-element tensors
self.assertTrue(utils._are_all_spins(torch.tensor([1])))
self.assertTrue(utils._are_all_spins(torch.tensor([-1])))
self.assertFalse(utils._are_all_spins(torch.tensor([0])))

# Zeros
self.assertFalse(utils._are_all_spins(torch.tensor([0, 1])))
self.assertFalse(utils._are_all_spins(torch.tensor([0, -1])))
self.assertFalse(utils._are_all_spins(torch.tensor([0, 0])))
# Nonzeros
self.assertFalse(utils._are_all_spins(torch.tensor([1, 1.2])))
self.assertFalse(utils._are_all_spins(-torch.tensor([1, 1.2])))

# All spins
self.assertTrue(utils._are_all_spins(torch.tensor([-1, 1])))
self.assertTrue(utils._are_all_spins(torch.tensor([-1.0, 1.0])))

def test__has_zeros(self):
        # Single-element tensors
self.assertFalse(utils._has_zeros(torch.tensor([1])))
self.assertTrue(utils._has_zeros(torch.tensor([0])))
self.assertTrue(utils._has_zeros(torch.tensor([-0])))

# Tensor
self.assertTrue(utils._has_zeros(torch.tensor([0, 1])))

def test__has_mixed_signs(self):
# Single entries cannot have mixed signs
self.assertFalse(utils._has_mixed_signs(torch.tensor([-0])))
self.assertFalse(utils._has_mixed_signs(torch.tensor([0])))
self.assertFalse(utils._has_mixed_signs(torch.tensor([1])))
self.assertFalse(utils._has_mixed_signs(torch.tensor([-1])))

# Zeros are unsigned
self.assertFalse(utils._has_mixed_signs(torch.tensor([0, 0])))
self.assertFalse(utils._has_mixed_signs(torch.tensor([0, 1.2])))
self.assertFalse(utils._has_mixed_signs(torch.tensor([0, -1.2])))

# All entries have same sign
self.assertFalse(utils._has_mixed_signs(torch.tensor([0.4, 1.2])))
self.assertFalse(utils._has_mixed_signs(-torch.tensor([0.4, 1.2])))

# Finally!
self.assertTrue(utils._has_mixed_signs(torch.tensor([-0.1, 1.2])))

def test__bounded_in_plus_minus_one(self):
# Violation on one end
self.assertFalse(utils._bounded_in_plus_minus_one(torch.tensor([1.2])))
self.assertFalse(utils._bounded_in_plus_minus_one(torch.tensor([-1.2])))
self.assertFalse(utils._bounded_in_plus_minus_one(torch.tensor([1.2, 0])))
self.assertFalse(utils._bounded_in_plus_minus_one(torch.tensor([-1.2, 0])))

# Boundary
self.assertTrue(utils._bounded_in_plus_minus_one(torch.tensor([1])))
self.assertTrue(utils._bounded_in_plus_minus_one(torch.tensor([-1])))
self.assertTrue(utils._bounded_in_plus_minus_one(torch.tensor([1, -1])))
self.assertTrue(utils._bounded_in_plus_minus_one(torch.tensor([1, 0])))
self.assertTrue(utils._bounded_in_plus_minus_one(torch.tensor([0, 1])))

# Correct
self.assertTrue(utils._bounded_in_plus_minus_one(torch.tensor([0.5, 0.9, -0.2])))

@parameterized.expand([[dict(a=1, x=4)], [dict(a="hello")]])
def test__has_correct_config(self, kwargs):
class MyModel(torch.nn.Module):
@store_config
def __init__(self, a, b=2, *, x=4, y=5):
super().__init__()

def forward(self, x):
return torch.ones(5)
model = MyModel(**kwargs)
self.assertTrue(utils._has_correct_config(model))
self.assertFalse(utils._has_correct_config(torch.nn.Linear(5, 3)))

def test__shapes_match(self):
shape = (123, 456)
x = torch.randn(shape)
self.assertTrue(utils._shapes_match(x, shape))
self.assertFalse(utils._shapes_match(x, (1, 2, 3)))

def test_model_probably_good(self):
with self.subTest("Model should be good"):
class MyModel(torch.nn.Module):
@store_config
def __init__(self, a, b=2, *, x=4, y=5):
super().__init__()

def forward(self, x):
return 2*x
self.assertTrue(utils.model_probably_good(MyModel("hello"), (500, ), (500,)))

with self.subTest("Model should be bad: config not stored"):
class MyModel(torch.nn.Module):
def __init__(self, a, b=2, *, x=4, y=5):
super().__init__()

def forward(self, x):
return 2*x
self.assertFalse(utils.model_probably_good(MyModel("hello"), (500, ), (500,)))

        with self.subTest("Model should be bad: shape mismatch"):
            class MyModel(torch.nn.Module):
                # store_config is applied so the only defect is the output shape
                @store_config
                def __init__(self, a, b=2, *, x=4, y=5):
                    super().__init__()

                def forward(self, x):
                    return torch.randn(500)
            self.assertFalse(utils.model_probably_good(MyModel("hello"), (123,), (123,)))

        with self.subTest("Model should be bad: constrained output"):
            class MyModel(torch.nn.Module):
                # store_config is applied so the only defect is the constant output
                @store_config
                def __init__(self, a, b=2, *, x=4, y=5):
                    super().__init__()

                def forward(self, x):
                    return torch.ones_like(x)
            self.assertFalse(utils.model_probably_good(MyModel("hello"), (123,), (123,)))


if __name__ == "__main__":
unittest.main()