Skip to content
Open
16 changes: 16 additions & 0 deletions dwave/plugins/torch/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dwave.plugins.torch.nn.modules import *
17 changes: 17 additions & 0 deletions dwave/plugins/torch/nn/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dwave.plugins.torch.nn.modules.linear import *
from dwave.plugins.torch.nn.modules.utils import *
100 changes: 100 additions & 0 deletions dwave/plugins/torch/nn/modules/linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from torch import nn

from dwave.plugins.torch.nn.modules.utils import store_config

__all__ = ["SkipLinear", "LinearBlock"]


class SkipLinear(nn.Module):
"""A linear transformation or the identity depending on whether input/output dimensions match.

This module is identity when ``din == dout``, otherwise, it is a linear transformation (no bias
term).

This is based on the `ResNet paper <https://arxiv.org/abs/1512.03385>`.

Args:
din (int): Size of each input sample.
dout (int): Size of each output sample.
"""

@store_config
def __init__(self, din: int, dout: int) -> None:
super().__init__()
if din == dout:
self.linear = nn.Identity()
else:
self.linear = nn.Linear(din, dout, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply a linear transformation to the input variable ``x``.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The linearly-transformed tensor of ``x``.
"""
return self.linear(x)


class LinearBlock(nn.Module):
"""A linear block consisting of normalizations, linear transformations, dropout, relu, and a skip connection.

The module is composed of (in order):

1. a first layer norm,
2. a first linear transformation,
3. a dropout,
4. a relu activation,
5. a second layer norm,
6. a second linear layer, and, finally,
7. a skip connection from initial input to output.

This is based on the `ResNet paper <https://arxiv.org/abs/1512.03385>`_.

Args:
din (int): Size of each input sample.
dout (int): Size of each output sample.
p (float): Dropout probability.
"""

@store_config
def __init__(self, din: int, dout: int, p: float) -> None:
super().__init__()
self._skip = SkipLinear(din, dout)
dhid = max(din, dout)
self._block = nn.Sequential(
nn.LayerNorm(din),
nn.Linear(din, dhid),
nn.Dropout(p),
nn.ReLU(),
nn.LayerNorm(dhid),
nn.Linear(dhid, dout),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Transforms the input ``x`` with the modules.

Args:
x (torch.Tensor): An input tensor.

Returns:
torch.Tensor: Output tensor.
"""
return self._block(x) + self._skip(x)
60 changes: 60 additions & 0 deletions dwave/plugins/torch/nn/modules/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
import inspect
from typing import TYPE_CHECKING, Callable

if TYPE_CHECKING:
from functools import partial

from functools import wraps
from types import MappingProxyType

__all__ = ["store_config"]


def store_config(fn: Callable) -> partial:
"""A decorator that tracks and stores arguments of methods (excluding ``self``).

.. note::
If an argument of the function also has a config attribute, then the argument's entry in
the dictionary will be replaced by the argument's config. For example, an argument ``foo`` has
a ``config`` attribute, i.e., ``foo.config`` exists, then ``self.config`` will contain the entry
``{"foo": foo.config}``. This is motivated by the convenience of storing configs of nested
modules.

Args:
fn (Callable[object, ...]): A method whose arguments will be stored in ``self.config``.

Returns:
partial: Wrapper function that stores argument of method.
"""
@wraps(fn)
def wrapper(self, *args, **kwargs):
"""Store ``args``, ``kwargs``, and ``{"module_name": self.__class__.__name__}`` as a dictionary in ``self.config``.
"""
sig = inspect.signature(fn)
bound = sig.bind(self, *args, **kwargs)
bound.apply_defaults()

config = {k: v for k, v in bound.arguments.items() if v != self}
config['module_name'] = self.__class__.__name__
for k, v in config.items():
if hasattr(v, "config"):
config[k] = v.config
self.config = MappingProxyType(config)

return fn(self, *args, **kwargs)
return wrapper
6 changes: 6 additions & 0 deletions releasenotes/notes/add-nn-modules-c29a092140eacbe1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- Add the Python module ``dwave.plugins.torch.nn`` containing commonly-used neural network modules
and patterns used to build more complex architectures.
- Add ``LinearBlock`` and ``SkipLinear`` modules.
- Add utilities for testing torch modules added to the ``nn`` Python submodule.
121 changes: 121 additions & 0 deletions tests/helper_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from inspect import signature

import torch


def model_probably_good(
model: torch.nn.Module, shape_in: tuple[int, ...], shape_out: tuple[int, ...]
) -> bool:
"""Checks whether the model output has expected shape, is probably unconstrained, and the model
has its configs stored.

This function generates dummy data with a padded batch dimension on top of the
input dimension (so ``shape_in`` should exclude a batch dimension). The data is passed through
the ``model``. Subsequent tests are described in ``shapes_match``, ``probably_unconstrained``,
and ``has_correct_config``.

Args:
model (torch.nn.Module): The module to be tested.
shape_in (tuple[int, ...]): Input data shape excluding the batch dimension.
shape_out (tuple[int, ...]): Output data shape excluding the batch dimension.

Returns:
bool: Indicator for whether the model meets the three conditions above.
"""
bs = 100
x = torch.randn((bs, ) + shape_in)
y = model(x)
padded_out = (bs,)+shape_out
return (shapes_match(y, padded_out)
and probably_unconstrained(y)
and has_correct_config(model))


def has_correct_config(model: torch.nn.Module) -> bool:
"""Checks whether the model has its initialization arguments stored in a ``config`` field.

Args:
model (torch.nn.Module): The module to be tested.

Returns:
bool: Indicator for whether the model has its initialization arguments stored.
"""
if not hasattr(model, "config"):
return False
sig = signature(model.__init__)
return set(model.config.keys()) == set(sig.parameters.keys()) | {"module_name"}


def shapes_match(x: torch.Tensor, y: tuple[int, ...]) -> bool:
"""Checks whether `x.shape` is equal to `y`.

Args:
x (torch.Tensor): A tensor.
y (tuple[int, ...]): The expected shape.

Returns:
bool: Indicator for whether the shape is as expected.
"""
return tuple(x.shape) == y


def are_all_spins(x: torch.Tensor) -> bool:
"""Checks all entries of `x` are one in absolute value.

Args:
x (torch.Tensor): A tensor.

Returns:
bool: indicator for whether all entries of `x` are in ``{-1, 1}``.
"""
return (x.float().abs() == 1).all()


def has_mixed_signs(x: torch.Tensor) -> bool:
"""Checks whether `x` has both positive and negative values.

Args:
x (torch.Tensor): A tensor to be cast to type float.

Returns:
bool: Indicator for whether `x` consists of both positive and negative values.
"""
return bool(x.max() > 0 and x.min() < 0)


def has_zeros(x: torch.Tensor) -> bool:
"""Checks whether `x` has exact zeros.

Args:
x (torch.Tensor): A tensor.

Returns:
bool: Indicator for whether `x` has any zero-valued entries.
"""
return (x == 0).float().any()


def bounded_in_plus_minus_one(x: torch.Tensor) -> bool:
"""Checks whether all entries of `x` are in ``[-1, 1]``.

Args:
x (torch.Tensor): A tensor.

Returns:
bool: Indicator for whether all values of `x` are in ``[-1, 1]``.
"""
return bool((x.abs() <= 1).all())


def probably_unconstrained(x: torch.Tensor):
"""Checks whether `x` has any activation-like constraints.
Checks `x` has no exact zeros, not bounded in ``[-1, 1]``, and has both positive and
negative-valued entries.

Args:
x (torch.Tensor): A tensor.

Returns:
bool: Indicator for whether `x` passes the "constraints".
"""
return not has_zeros(x) and not bounded_in_plus_minus_one(x) and has_mixed_signs(x)
Loading