Skip to content

Commit 7457953

Browse files
authored
Function wrappers and JIT compatibility for most of the package (#59)
* Added script decorators to many of the util funcs * HardOptimalCompletionLoss trace test * MinimumErrorRateLoss tracing, tests * Made attention layers scriptable, at least * SpecAugment tracing, tests * RandomShift tracing, tests * _string_matching changes * Fixes * Maybe this works? * compat * ctc_prefix_search script_if_tracing * Refactored + made jit decisions more centralized * Eeeo * Removed del statements in scripted modules * Bug fix for sequential_log_probs * Added SequentialLogProbabilities module * Try this * Suppress some test warnings * Fixed attention trace * Now? * broadcast_shapes compatibility * Compatibility stuff * A little more compatibility stuff * scripting SequentialLanguageModel * Avoid LookupLanguageModel scripting on 1.5.1 * What's happening now? * Skip scripting tests for 1.5.1 * Test fixes * Bug fixes and debugging 1.5.1 * Keep try * What now? * Pluz * ctc prefix search, mayhaps * beam * Oops * Ugh * Again, with feeling * Ok! * Avoid unnecessary recursion * DenseImageWarp * PolyharmonicSpline * SparseImageWarp * String matching * PadVariable * Warp1DGrid * TimeDistributedReturn * Bug fixes for CUDA and changelog
1 parent 66633c9 commit 7457953

17 files changed

+4255
-2417
lines changed

.appveyor.yml

+6-2
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@ image: Ubuntu
55

66
environment:
77
matrix:
8-
- TOXENV: py36-earliest
8+
- TOXENV: py36-t151
99
PYTHON: "3.6"
10-
- TOXENV: py37
10+
- TOXENV: py36-t181
11+
PYTHON: "3.6"
12+
- TOXENV: py37-t181
1113
PYTHON: "3.7"
1214
- TOXENV: py38
1315
PYTHON: "3.8"
16+
- TOXENV: py38-151
17+
PYTHON: "3.8"
1418
- TOXENV: py39
1519
PYTHON: "3.9"
1620
# - TOXENV: py310 # wheel not available yet

CHANGELOG.md

+9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22

33
## HEAD
44

5+
- Added a number of modules to `pydrobert.torch.layers` as a wrapper around the
6+
functional versions.
7+
- Added compatibility wrappers to avoid warnings across supported pytorch
8+
versions.
9+
- Refactored code and added tests to support JIT tracing and scripting for most
10+
functions/modules in pytorch >= 1.8.1. Did not handle those in
11+
`pydrobert.torch.estimators` yet because I plan on revamping that code
12+
before the next release. I'll write up documentation shortly.
13+
- Added `pydrobert.torch.config` to store constants used in the module.
514
- Removed `setup.py`.
615
- Deleted conda recipe in prep for [conda-forge](https://conda-forge.org/).
716
- Compatibility/determinism fixes for 1.5.1.

LICENSE_pytorch.txt

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
From PyTorch:
2+
3+
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
4+
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
5+
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
6+
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
7+
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
8+
Copyright (c) 2011-2013 NYU (Clement Farabet)
9+
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
10+
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
11+
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
12+
13+
From Caffe2:
14+
15+
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
16+
17+
All contributions by Facebook:
18+
Copyright (c) 2016 Facebook Inc.
19+
20+
All contributions by Google:
21+
Copyright (c) 2015 Google Inc.
22+
All rights reserved.
23+
24+
All contributions by Yangqing Jia:
25+
Copyright (c) 2015 Yangqing Jia
26+
All rights reserved.
27+
28+
All contributions by Kakao Brain:
29+
Copyright 2019-2020 Kakao Brain
30+
31+
All contributions from Caffe:
32+
Copyright(c) 2013, 2014, 2015, the respective contributors
33+
All rights reserved.
34+
35+
All other contributions:
36+
Copyright(c) 2015, 2016 the respective contributors
37+
All rights reserved.
38+
39+
Caffe2 uses a copyright model similar to Caffe: each contributor holds
40+
copyright over their contributions to Caffe2. The project versioning records
41+
all such contribution and copyright details. If a contributor wants to further
42+
mark their specific copyright on a particular contribution, they should
43+
indicate their copyright solely in the commit message of the change when it is
44+
committed.
45+
46+
All rights reserved.
47+
48+
Redistribution and use in source and binary forms, with or without
49+
modification, are permitted provided that the following conditions are met:
50+
51+
1. Redistributions of source code must retain the above copyright
52+
notice, this list of conditions and the following disclaimer.
53+
54+
2. Redistributions in binary form must reproduce the above copyright
55+
notice, this list of conditions and the following disclaimer in the
56+
documentation and/or other materials provided with the distribution.
57+
58+
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
59+
and IDIAP Research Institute nor the names of its contributors may be
60+
used to endorse or promote products derived from this software without
61+
specific prior written permission.
62+
63+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
64+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
65+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
66+
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
67+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
68+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
69+
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
70+
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
71+
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
72+
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
73+
POSSIBILITY OF SUCH DAMAGE.

README.md

+4
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,7 @@ details.
3737
Implementations of `pydrobert.torch.util.polyharmonic_spline` and
3838
`pydrobert.torch.util.sparse_image_warp` are based off Tensorflow's codebase,
3939
which is Apache 2.0 licensed.
40+
41+
Implementation of `pydrobert.torch._compat.broadcast_shapes` was directly
42+
taken from the Pytorch codebase, which has a BSD-style license, found in
43+
the file `LICENSE_pytorch`.

pytest.ini

+2
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22
markers =
33
cpu : tests are on the cpu
44
gpu : tests are on the gpu
5+
trace : tests involve tracing code (TorchScript)
6+
script : tests involve scripting code (TorchScript)

src/pydrobert/torch/__init__.py

+1-14
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,10 @@
2323
__version__ = "inplace"
2424

2525
__all__ = [
26-
"command_line",
26+
"config",
2727
"data",
2828
"estimators",
29-
"INDEX_PAD_VALUE",
3029
"layers",
3130
"training",
3231
"util",
3332
]
34-
35-
36-
"""The value to pad index-based tensors with
37-
38-
Batched operations often involve variable-width input. This value is used to
39-
right-pad indexed-based tensors with to indicate that this element should be
40-
ignored.
41-
42-
The default value (-100) was chosen to coincide with the PyTorch 1.0 default
43-
for ``ignore_index`` in the likelihood losses
44-
"""
45-
INDEX_PAD_VALUE = -100

src/pydrobert/torch/_compat.py

+244
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
# Copyright 2022 Sean Robertson
2+
#
3+
# Code for broadcast_shapes was adapted from PyTorch
4+
# https://github.com/pytorch/pytorch/blob/2367face24afb159f73ebf40dc6f23e46132b770/torch/functional.py
5+
# Code for TorchVersion was taken directly from PyTorch
6+
# https://github.com/pytorch/pytorch/blob/b737e09f60dd56dbae520e436648e1f3ebc1f937/torch/torch_version.py
7+
# See LICENSE_pytorch in project root directory for PyTorch license.
8+
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
from typing import Any, Iterable, List, Optional, Tuple, Union, NamedTuple, Set
22+
23+
import torch
24+
import pydrobert.torch.config as config
25+
26+
27+
__all__ = [
28+
"broadcast_shapes",
29+
"jit_isinstance",
30+
"linalg_solve",
31+
"meshgrid",
32+
"pad_sequence",
33+
"script",
34+
"SpoofPackedSequence",
35+
"trunc_divide",
36+
]
37+
38+
39+
# to avoid some scripting issues with torch.utils.nn.PackedSequence
40+
class SpoofPackedSequence(NamedTuple):
41+
data: torch.Tensor
42+
batch_sizes: torch.Tensor
43+
sorted_indices: Optional[torch.Tensor]
44+
unsorted_indices: Optional[torch.Tensor]
45+
46+
47+
if config.USE_JIT:
48+
script = torch.jit.script
49+
else:
50+
try:
51+
script = torch.jit.script_if_tracing
52+
except AttributeError:
53+
54+
def script(obj, *args, **kwargs):
55+
return obj
56+
57+
58+
try:
59+
from torch.torch_version import __version__ as _v # type: ignore
60+
except ModuleNotFoundError:
61+
from torch.version import __version__ as internal_version
62+
from pkg_resources import packaging # type: ignore[attr-defined]
63+
64+
Version = packaging.version.Version
65+
InvalidVersion = packaging.version.InvalidVersion
66+
67+
class TorchVersion(str):
68+
"""A string with magic powers to compare to both Version and iterables!
69+
Prior to 1.10.0 torch.__version__ was stored as a str and so many did
70+
comparisons against torch.__version__ as if it were a str. In order to not
71+
break them we have TorchVersion which masquerades as a str while also
72+
having the ability to compare against both packaging.version.Version as
73+
well as tuples of values, eg. (1, 2, 1)
74+
Examples:
75+
Comparing a TorchVersion object to a Version object
76+
TorchVersion('1.10.0a') > Version('1.10.0a')
77+
Comparing a TorchVersion object to a Tuple object
78+
TorchVersion('1.10.0a') > (1, 2) # 1.2
79+
TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1
80+
Comparing a TorchVersion object against a string
81+
TorchVersion('1.10.0a') > '1.2'
82+
TorchVersion('1.10.0a') > '1.2.1'
83+
"""
84+
85+
# fully qualified type names here to appease mypy
86+
def _convert_to_version(
87+
self, inp: Union[packaging.version.Version, str, Iterable]
88+
) -> packaging.version.Version:
89+
if isinstance(inp, Version):
90+
return inp
91+
elif isinstance(inp, str):
92+
return Version(inp)
93+
elif isinstance(inp, Iterable):
94+
# Ideally this should work for most cases by attempting to group
95+
# the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH)
96+
# Examples:
97+
# * (1) -> Version("1")
98+
# * (1, 20) -> Version("1.20")
99+
# * (1, 20, 1) -> Version("1.20.1")
100+
return Version(".".join((str(item) for item in inp)))
101+
else:
102+
raise InvalidVersion(inp)
103+
104+
def __gt__(self, cmp):
105+
try:
106+
return Version(self).__gt__(self._convert_to_version(cmp))
107+
except InvalidVersion:
108+
# Fall back to regular string comparison if dealing with an invalid
109+
# version like 'parrot'
110+
return super().__gt__(cmp)
111+
112+
def __lt__(self, cmp):
113+
try:
114+
return Version(self).__lt__(self._convert_to_version(cmp))
115+
except InvalidVersion:
116+
# Fall back to regular string comparison if dealing with an invalid
117+
# version like 'parrot'
118+
return super().__lt__(cmp)
119+
120+
def __eq__(self, cmp):
121+
try:
122+
return Version(self).__eq__(self._convert_to_version(cmp))
123+
except InvalidVersion:
124+
# Fall back to regular string comparison if dealing with an invalid
125+
# version like 'parrot'
126+
return super().__eq__(cmp)
127+
128+
def __ge__(self, cmp):
129+
try:
130+
return Version(self).__ge__(self._convert_to_version(cmp))
131+
except InvalidVersion:
132+
# Fall back to regular string comparison if dealing with an invalid
133+
# version like 'parrot'
134+
return super().__ge__(cmp)
135+
136+
def __le__(self, cmp):
137+
try:
138+
return Version(self).__le__(self._convert_to_version(cmp))
139+
except InvalidVersion:
140+
# Fall back to regular string comparison if dealing with an invalid
141+
# version like 'parrot'
142+
return super().__le__(cmp)
143+
144+
_v = TorchVersion(internal_version)
145+
146+
if _v < "1.8.0":
147+
148+
@script
149+
def pad_sequence(
150+
sequences: List[torch.Tensor],
151+
batch_first: bool = False,
152+
padding_value: float = 0.0,
153+
) -> torch.Tensor:
154+
shape = sequences[0].size()
155+
shape_rest = shape[1:]
156+
lens = [x.size(0) for x in sequences]
157+
max_len = max(lens)
158+
pad_shapes = [(max_len - x,) + shape_rest for x in lens]
159+
sequences = [
160+
torch.cat(
161+
[
162+
seq,
163+
torch.full(ps, padding_value, device=seq.device, dtype=seq.dtype),
164+
],
165+
0,
166+
)
167+
for seq, ps in zip(sequences, pad_shapes)
168+
]
169+
return torch.stack(sequences, 0 if batch_first else 1)
170+
171+
def linalg_solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
172+
return torch.solve(B, A)[0]
173+
174+
@torch.jit.unused
175+
def _jit_isinstance(obj: Any, x: type) -> bool:
176+
if isinstance(obj, torch.nn.utils.rnn.PackedSequence):
177+
obj = obj.data, obj.batch_sizes, obj.sorted_indices, obj.unsorted_indices
178+
origin = getattr(x, "__origin__", None)
179+
if origin is None:
180+
return isinstance(obj, x)
181+
if origin in {tuple, list, set, List, Set, Tuple}:
182+
args = getattr(x, "__args__", None)
183+
if not args:
184+
return (
185+
(origin in {tuple, Tuple} and obj == tuple())
186+
or (origin in {list, List} and obj == list())
187+
or (origin in {set, Set} and obj == set())
188+
)
189+
if origin in {tuple, Tuple}:
190+
return (len(obj) is len(args)) and all(
191+
_jit_isinstance(*y) for y in zip(obj, args)
192+
)
193+
else:
194+
assert len(args) == 1
195+
return all(_jit_isinstance(o, args[0]) for o in obj)
196+
elif origin is Union:
197+
args = x.__args__
198+
return any(_jit_isinstance(obj, y) for y in args)
199+
return False
200+
201+
def jit_isinstance(obj: Any, x: type) -> bool:
202+
if torch.jit.is_scripting():
203+
return isinstance(obj, x)
204+
else:
205+
return _jit_isinstance(obj, x)
206+
207+
208+
else:
209+
pad_sequence = torch.nn.utils.rnn.pad_sequence
210+
linalg_solve = torch.linalg.solve
211+
jit_isinstance = torch.jit.isinstance
212+
213+
214+
@torch.no_grad()
215+
def broadcast_shapes(a: List[int], b: List[int]) -> List[int]:
216+
scalar = torch.zeros((), device="cpu")
217+
tensor_a = scalar.expand(a)
218+
tensor_b = scalar.expand(b)
219+
tensor_a, tensor_b = torch.broadcast_tensors(tensor_a, tensor_b)
220+
return tensor_a.shape
221+
222+
223+
if _v < "1.10.0":
224+
meshgrid = torch.meshgrid
225+
226+
trunc_divide = torch.floor_divide
227+
else:
228+
229+
def trunc_divide(input: torch.Tensor, other: Any) -> torch.Tensor:
230+
if not torch.jit.is_scripting():
231+
return input.div(other, rounding_mode="trunc")
232+
elif torch.jit.isinstance(other, float):
233+
return input.div(other, rounding_mode="trunc")
234+
elif torch.jit.isinstance(other, int):
235+
return input.div(other, rounding_mode="trunc")
236+
elif torch.jit.isinstance(other, torch.Tensor):
237+
return input.div(other, rounding_mode="trunc")
238+
else:
239+
assert False
240+
241+
def meshgrid(a: torch.Tensor, b: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
242+
x = torch.meshgrid(a, b, indexing="ij")
243+
assert len(x) == 2
244+
return x[0], x[1]

0 commit comments

Comments
 (0)