# Add neural network modules, tests, and test utils #33
**Open** · kevinchern wants to merge 13 commits into `dwavesystems:main` from `kevinchern:feature/nn`
## Commits (13, all by kevinchern)

- `621fd0a` Add neural network modules, tests, and test utils
- `a42973c` Use decorator to store configs and add module test
- `c85393e` Add docstrings
- `495851c` Address review comments
- `0a99594` Store nested configs and cite ResNet
- `e2f9986` Improve docstrings and fix typos
- `9467786` Refactor nn.py into a module
- `7ca53f2` Remove leading underscore for test functions
- `c7028d4` Apply suggestions from code review
- `a72e07a` Address PR comments
- `9b53973` Fix typo
- `7062a11` Separate tests and add more store_config tests
- `da72736` Apply suggestions from code review
## Files changed
**New file** (16 lines) — a package `__init__`; from its import this is presumably `dwave/plugins/torch/nn/__init__.py` (file paths are not preserved in this view):

```python
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dwave.plugins.torch.nn.modules import *
```
**New file** (17 lines) — presumably `dwave/plugins/torch/nn/modules/__init__.py`, inferred from its imports:

```python
# Copyright 2025 D-Wave
# (Apache License 2.0 header, identical to the file above)

from dwave.plugins.torch.nn.modules.linear import *
from dwave.plugins.torch.nn.modules.utils import *
```
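If the inferred layout above is right, these star-imports, together with each module's `__all__` (shown in the files below), should make the public names importable from the package root — a minimal sketch, not a usage example from the PR itself:

```python
# Assumes the inferred package layout above.
from dwave.plugins.torch.nn import LinearBlock, SkipLinear, store_config
```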
**New file** (100 lines) — presumably `dwave/plugins/torch/nn/modules/linear.py`, defining the two neural network modules:

```python
# Copyright 2025 D-Wave
# (Apache License 2.0 header, identical to the file above)
import torch
from torch import nn

from dwave.plugins.torch.nn.modules.utils import store_config

__all__ = ["SkipLinear", "LinearBlock"]


class SkipLinear(nn.Module):
    """A linear transformation or the identity, depending on whether input/output dimensions match.

    This module is the identity when ``din == dout``; otherwise, it is a linear transformation
    (no bias term).

    This is based on the `ResNet paper <https://arxiv.org/abs/1512.03385>`_.

    Args:
        din (int): Size of each input sample.
        dout (int): Size of each output sample.
    """

    @store_config
    def __init__(self, din: int, dout: int) -> None:
        super().__init__()
        if din == dout:
            self.linear = nn.Identity()
        else:
            self.linear = nn.Linear(din, dout, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply a linear transformation to the input ``x``.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The linearly transformed tensor.
        """
        return self.linear(x)


class LinearBlock(nn.Module):
    """A linear block consisting of normalizations, linear transformations, dropout, ReLU, and a skip connection.

    The module is composed of (in order):

    1. a first layer norm,
    2. a first linear transformation,
    3. a dropout,
    4. a ReLU activation,
    5. a second layer norm,
    6. a second linear layer, and, finally,
    7. a skip connection from the initial input to the output.

    This is based on the `ResNet paper <https://arxiv.org/abs/1512.03385>`_.

    Args:
        din (int): Size of each input sample.
        dout (int): Size of each output sample.
        p (float): Dropout probability.
    """

    @store_config
    def __init__(self, din: int, dout: int, p: float) -> None:
        super().__init__()
        self._skip = SkipLinear(din, dout)
        dhid = max(din, dout)
        self._block = nn.Sequential(
            nn.LayerNorm(din),
            nn.Linear(din, dhid),
            nn.Dropout(p),
            nn.ReLU(),
            nn.LayerNorm(dhid),
            nn.Linear(dhid, dout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transform the input ``x`` with the block and add the skip connection.

        Args:
            x (torch.Tensor): An input tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        return self._block(x) + self._skip(x)
```
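A minimal usage sketch of the two modules; the dimensions and batch size here are illustrative, not taken from the PR:

```python
import torch
from dwave.plugins.torch.nn.modules.linear import LinearBlock, SkipLinear

# SkipLinear is the identity when din == dout, a bias-free projection otherwise.
identity_skip = SkipLinear(din=8, dout=8)  # wraps nn.Identity
projection = SkipLinear(din=8, dout=4)     # wraps nn.Linear(8, 4, bias=False)

# LinearBlock maps (batch, din) -> (batch, dout) with a residual connection.
block = LinearBlock(din=8, dout=4, p=0.1)
x = torch.randn(32, 8)
y = block(x)
assert y.shape == (32, 4)
```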
**New file** (60 lines) — presumably `dwave/plugins/torch/nn/modules/utils.py`, defining the `store_config` decorator. (Two small fixes relative to the diff: the decorator returns the wrapper function rather than a ``partial``, so the annotation and the now-unused ``TYPE_CHECKING`` import are corrected, and ``self`` is filtered by identity rather than equality.)

```python
# Copyright 2025 D-Wave
# (Apache License 2.0 header, identical to the file above)
from __future__ import annotations
import inspect
from functools import wraps
from types import MappingProxyType
from typing import Callable

__all__ = ["store_config"]


def store_config(fn: Callable) -> Callable:
    """A decorator that tracks and stores the arguments of a method (excluding ``self``).

    .. note::
        If an argument of the function itself has a ``config`` attribute, then the argument's
        entry in the dictionary is replaced by that config. For example, if an argument ``foo``
        has a ``config`` attribute, i.e., ``foo.config`` exists, then ``self.config`` will
        contain the entry ``{"foo": foo.config}``. This is motivated by the convenience of
        storing configs of nested modules.

    Args:
        fn (Callable): A method whose arguments will be stored in ``self.config``.

    Returns:
        Callable: A wrapped method that stores the arguments of ``fn``.
    """
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        """Store ``args``, ``kwargs``, and ``{"module_name": self.__class__.__name__}`` as a
        read-only mapping in ``self.config``.
        """
        sig = inspect.signature(fn)
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()

        # Filter out ``self`` by identity: ``!=`` could invoke a custom
        # ``__eq__`` (or fail outright for tensor-valued arguments).
        config = {k: v for k, v in bound.arguments.items() if v is not self}
        config['module_name'] = self.__class__.__name__
        for k, v in config.items():
            if hasattr(v, "config"):
                config[k] = v.config
        self.config = MappingProxyType(config)

        return fn(self, *args, **kwargs)
    return wrapper
```
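For illustration, here is what the decorator captures for the modules above — the printed values are derived from the decorator's logic, not output quoted from the PR:

```python
from dwave.plugins.torch.nn.modules.linear import LinearBlock, SkipLinear

# Arguments of the decorated __init__ are stored, plus the class name.
skip = SkipLinear(3, 5)
print(dict(skip.config))
# {'din': 3, 'dout': 5, 'module_name': 'SkipLinear'}

# The stored mapping is read-only (a MappingProxyType).
block = LinearBlock(din=8, dout=4, p=0.1)
print(dict(block.config))
# {'din': 8, 'dout': 4, 'p': 0.1, 'module_name': 'LinearBlock'}
```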
**New file** (6 lines) — a release note (YAML indentation reconstructed):

```yaml
---
features:
  - Add the Python module ``dwave.plugins.torch.nn`` containing commonly-used neural network
    modules and patterns used to build more complex architectures.
  - Add ``LinearBlock`` and ``SkipLinear`` modules.
  - Add utilities for testing torch modules added to the ``nn`` Python submodule.
```
**New file** (121 lines) — test utilities for torch modules (the path is not preserved in this view):

```python
from inspect import signature

import torch


def model_probably_good(
    model: torch.nn.Module, shape_in: tuple[int, ...], shape_out: tuple[int, ...]
) -> bool:
    """Check whether the model output has the expected shape and is probably unconstrained, and
    whether the model has its configs stored.

    This function generates dummy data with a batch dimension padded on top of the input
    dimensions (so ``shape_in`` should exclude a batch dimension) and passes it through
    ``model``. The individual tests are described in ``shapes_match``,
    ``probably_unconstrained``, and ``has_correct_config``.

    Args:
        model (torch.nn.Module): The module to be tested.
        shape_in (tuple[int, ...]): Input data shape, excluding the batch dimension.
        shape_out (tuple[int, ...]): Output data shape, excluding the batch dimension.

    Returns:
        bool: Indicator for whether the model meets the three conditions above.
    """
    bs = 100
    x = torch.randn((bs,) + shape_in)
    y = model(x)
    padded_out = (bs,) + shape_out
    return (shapes_match(y, padded_out)
            and probably_unconstrained(y)
            and has_correct_config(model))


def has_correct_config(model: torch.nn.Module) -> bool:
    """Check whether the model has its initialization arguments stored in a ``config`` field.

    Args:
        model (torch.nn.Module): The module to be tested.

    Returns:
        bool: Indicator for whether the model has its initialization arguments stored.
    """
    if not hasattr(model, "config"):
        return False
    sig = signature(model.__init__)
    return set(model.config.keys()) == set(sig.parameters.keys()) | {"module_name"}


def shapes_match(x: torch.Tensor, y: tuple[int, ...]) -> bool:
    """Check whether ``x.shape`` is equal to ``y``.

    Args:
        x (torch.Tensor): A tensor.
        y (tuple[int, ...]): The expected shape.

    Returns:
        bool: Indicator for whether the shape is as expected.
    """
    return tuple(x.shape) == y


def are_all_spins(x: torch.Tensor) -> bool:
    """Check whether all entries of ``x`` are one in absolute value.

    Args:
        x (torch.Tensor): A tensor.

    Returns:
        bool: Indicator for whether all entries of ``x`` are in ``{-1, 1}``.
    """
    return bool((x.float().abs() == 1).all())


def has_mixed_signs(x: torch.Tensor) -> bool:
    """Check whether ``x`` has both positive and negative values.

    Args:
        x (torch.Tensor): A tensor.

    Returns:
        bool: Indicator for whether ``x`` contains both positive and negative values.
    """
    return bool(x.max() > 0 and x.min() < 0)


def has_zeros(x: torch.Tensor) -> bool:
    """Check whether ``x`` has exact zeros.

    Args:
        x (torch.Tensor): A tensor.

    Returns:
        bool: Indicator for whether ``x`` has any zero-valued entries.
    """
    return bool((x == 0).any())


def bounded_in_plus_minus_one(x: torch.Tensor) -> bool:
    """Check whether all entries of ``x`` are in ``[-1, 1]``.

    Args:
        x (torch.Tensor): A tensor.

    Returns:
        bool: Indicator for whether all values of ``x`` are in ``[-1, 1]``.
    """
    return bool((x.abs() <= 1).all())


def probably_unconstrained(x: torch.Tensor) -> bool:
    """Check whether ``x`` looks free of activation-like constraints.

    Checks that ``x`` has no exact zeros, is not bounded in ``[-1, 1]``, and has both positive
    and negative entries.

    Args:
        x (torch.Tensor): A tensor.

    Returns:
        bool: Indicator for whether ``x`` passes all three checks.
    """
    return not has_zeros(x) and not bounded_in_plus_minus_one(x) and has_mixed_signs(x)
```
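A sketch of how the helpers compose, using the `LinearBlock` from this PR. The shapes are illustrative, and since the check runs on random data its result is probabilistic rather than guaranteed:

```python
from dwave.plugins.torch.nn.modules.linear import LinearBlock

model = LinearBlock(din=8, dout=4, p=0.1)
model.eval()  # disable dropout so stochastic zeroing cannot affect the output checks

# Verifies output shape (100, 4), unconstrained-looking outputs, and a stored config.
print(model_probably_good(model, shape_in=(8,), shape_out=(4,)))  # expected: True
```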