# Copyright 2022 Sean Robertson
#
# Code for broadcast_shapes was adapted from PyTorch
# https://github.com/pytorch/pytorch/blob/2367face24afb159f73ebf40dc6f23e46132b770/torch/functional.py
# Code for TorchVersion was taken directly from PyTorch
# https://github.com/pytorch/pytorch/blob/b737e09f60dd56dbae520e436648e1f3ebc1f937/torch/torch_version.py
# See LICENSE_pytorch in project root directory for PyTorch license.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Iterable, List, Optional, Tuple, Union, NamedTuple, Set

import torch
import pydrobert.torch.config as config


__all__ = [
    "broadcast_shapes",
    "jit_isinstance",
    "linalg_solve",
    "meshgrid",
    "pad_sequence",
    "script",
    "SpoofPackedSequence",
    "trunc_divide",
]


# to avoid some scripting issues with torch.nn.utils.rnn.PackedSequence
class SpoofPackedSequence(NamedTuple):
    data: torch.Tensor
    batch_sizes: torch.Tensor
    sorted_indices: Optional[torch.Tensor]
    unsorted_indices: Optional[torch.Tensor]

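# A usage sketch (illustrative, not part of the original file): scripted code can
# take a SpoofPackedSequence built from a real PackedSequence's fields, e.g.
#
#   packed = torch.nn.utils.rnn.pack_sequence([torch.arange(3), torch.arange(2)])
#   spoofed = SpoofPackedSequence(
#       packed.data, packed.batch_sizes, packed.sorted_indices,
#       packed.unsorted_indices,
#   )
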
if config.USE_JIT:
    script = torch.jit.script
else:
    try:
        script = torch.jit.script_if_tracing
    except AttributeError:
        # no script_if_tracing in this version of torch; fall back to a no-op
        # decorator
        def script(obj, *args, **kwargs):
            return obj

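# A usage sketch (illustrative): `script` is applied like torch.jit.script, but
# compiles only when USE_JIT is set (or, failing that, only while tracing)
#
#   @script
#   def double(x: torch.Tensor) -> torch.Tensor:
#       return x * 2
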
try:
    from torch.torch_version import __version__ as _v  # type: ignore
except ModuleNotFoundError:
    from torch.version import __version__ as internal_version
    from pkg_resources import packaging  # type: ignore[attr-defined]

    Version = packaging.version.Version
    InvalidVersion = packaging.version.InvalidVersion

    class TorchVersion(str):
        """A string with magic powers to compare to both Version and iterables!

        Prior to 1.10.0, torch.__version__ was stored as a str, so many users
        compared against torch.__version__ as if it were a str. In order not to
        break them, TorchVersion masquerades as a str while also being comparable
        against both packaging.version.Version and tuples of values, e.g. (1, 2, 1).

        Examples:
            Comparing a TorchVersion object to a Version object
                TorchVersion('1.10.0a') > Version('1.10.0a')
            Comparing a TorchVersion object to a Tuple object
                TorchVersion('1.10.0a') > (1, 2)     # 1.2
                TorchVersion('1.10.0a') > (1, 2, 1)  # 1.2.1
            Comparing a TorchVersion object against a string
                TorchVersion('1.10.0a') > '1.2'
                TorchVersion('1.10.0a') > '1.2.1'
        """

        # fully qualified type names here to appease mypy
        def _convert_to_version(
            self, inp: Union[packaging.version.Version, str, Iterable]
        ) -> packaging.version.Version:
            if isinstance(inp, Version):
                return inp
            elif isinstance(inp, str):
                return Version(inp)
            elif isinstance(inp, Iterable):
                # Ideally this should work for most cases by attempting to group
                # the version tuple, assuming it looks like (MAJOR, MINOR, ?PATCH)
                # Examples:
                #   * (1,) -> Version("1")
                #   * (1, 20) -> Version("1.20")
                #   * (1, 20, 1) -> Version("1.20.1")
                return Version(".".join(str(item) for item in inp))
            else:
                raise InvalidVersion(inp)

        def __gt__(self, cmp):
            try:
                return Version(self).__gt__(self._convert_to_version(cmp))
            except InvalidVersion:
                # Fall back to regular string comparison if dealing with an invalid
                # version like 'parrot'
                return super().__gt__(cmp)

        def __lt__(self, cmp):
            try:
                return Version(self).__lt__(self._convert_to_version(cmp))
            except InvalidVersion:
                # Fall back to regular string comparison if dealing with an invalid
                # version like 'parrot'
                return super().__lt__(cmp)

        def __eq__(self, cmp):
            try:
                return Version(self).__eq__(self._convert_to_version(cmp))
            except InvalidVersion:
                # Fall back to regular string comparison if dealing with an invalid
                # version like 'parrot'
                return super().__eq__(cmp)

        def __ge__(self, cmp):
            try:
                return Version(self).__ge__(self._convert_to_version(cmp))
            except InvalidVersion:
                # Fall back to regular string comparison if dealing with an invalid
                # version like 'parrot'
                return super().__ge__(cmp)

        def __le__(self, cmp):
            try:
                return Version(self).__le__(self._convert_to_version(cmp))
            except InvalidVersion:
                # Fall back to regular string comparison if dealing with an invalid
                # version like 'parrot'
                return super().__le__(cmp)

        # defining __eq__ would otherwise set __hash__ to None, making instances
        # unhashable; restore str's hash explicitly
        __hash__ = str.__hash__

    _v = TorchVersion(internal_version)

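# A usage sketch (illustrative): whichever branch defined it, _v supports mixed
# comparisons, e.g.
#
#   _v < "1.8.0"   # against a version string
#   _v >= (1, 10)  # against a (MAJOR, MINOR) tuple
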
if _v < "1.8.0":
    # fallbacks for older versions of torch

    @script
    def pad_sequence(
        sequences: List[torch.Tensor],
        batch_first: bool = False,
        padding_value: float = 0.0,
    ) -> torch.Tensor:
        # pad each sequence to max_len along dim 0 with padding_value, then stack
        shape = sequences[0].size()
        shape_rest = shape[1:]
        lens = [x.size(0) for x in sequences]
        max_len = max(lens)
        pad_shapes = [(max_len - x,) + shape_rest for x in lens]
        sequences = [
            torch.cat(
                [
                    seq,
                    torch.full(ps, padding_value, device=seq.device, dtype=seq.dtype),
                ],
                0,
            )
            for seq, ps in zip(sequences, pad_shapes)
        ]
        return torch.stack(sequences, 0 if batch_first else 1)
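    # A usage sketch (illustrative):
    #
    #   seqs = [torch.ones(3), torch.ones(1)]
    #   pad_sequence(seqs, batch_first=True, padding_value=-1.0)
    #   # -> tensor([[ 1.,  1.,  1.],
    #   #            [ 1., -1., -1.]])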

    def linalg_solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        # torch.solve predates torch.linalg.solve and returns (solution, LU);
        # keep only the solution to match the torch.linalg.solve interface
        return torch.solve(B, A)[0]

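    # A usage sketch (illustrative): solves A @ x = B for x
    #
    #   A = torch.eye(2)
    #   B = torch.tensor([[1.0], [2.0]])
    #   x = linalg_solve(A, B)  # equals B here, since A is the identity
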
    @torch.jit.unused
    def _jit_isinstance(obj: Any, x: type) -> bool:
        if isinstance(obj, torch.nn.utils.rnn.PackedSequence):
            obj = obj.data, obj.batch_sizes, obj.sorted_indices, obj.unsorted_indices
        origin = getattr(x, "__origin__", None)
        if origin is None:
            return isinstance(obj, x)
        if origin in {tuple, list, set, List, Set, Tuple}:
            args = getattr(x, "__args__", None)
            if not args:
                return (
                    (origin in {tuple, Tuple} and obj == tuple())
                    or (origin in {list, List} and obj == list())
                    or (origin in {set, Set} and obj == set())
                )
            if origin in {tuple, Tuple}:
                return (len(obj) == len(args)) and all(
                    _jit_isinstance(*y) for y in zip(obj, args)
                )
            else:
                assert len(args) == 1
                return all(_jit_isinstance(o, args[0]) for o in obj)
        elif origin is Union:
            args = x.__args__
            return any(_jit_isinstance(obj, y) for y in args)
        return False

    def jit_isinstance(obj: Any, x: type) -> bool:
        if torch.jit.is_scripting():
            return isinstance(obj, x)
        else:
            return _jit_isinstance(obj, x)
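    # A usage sketch (illustrative): handles typing generics the way
    # torch.jit.isinstance does in newer versions of torch
    #
    #   jit_isinstance((torch.zeros(1), 2), Tuple[torch.Tensor, int])  # True
    #   jit_isinstance(None, Optional[torch.Tensor])                   # True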


else:
    pad_sequence = torch.nn.utils.rnn.pad_sequence
    linalg_solve = torch.linalg.solve
    jit_isinstance = torch.jit.isinstance


@torch.no_grad()
def broadcast_shapes(a: List[int], b: List[int]) -> List[int]:
    # adapted from torch.functional.broadcast_shapes (see header): determine the
    # broadcasted shape by expanding a cost-free scalar view to each shape
    scalar = torch.zeros((), device="cpu")
    tensor_a = scalar.expand(a)
    tensor_b = scalar.expand(b)
    tensor_a, tensor_b = torch.broadcast_tensors(tensor_a, tensor_b)
    return tensor_a.shape
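# A usage sketch (illustrative):
#
#   broadcast_shapes([3, 1], [4])  # -> torch.Size([3, 4]) (a List[int] if scripted)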


if _v < "1.10.0":
    # torch.meshgrid always used "ij" indexing prior to 1.10
    meshgrid = torch.meshgrid

    # prior to 1.10, torch.floor_divide actually rounded towards zero
    trunc_divide = torch.floor_divide
else:

    def trunc_divide(input: torch.Tensor, other: Any) -> torch.Tensor:
        # the isinstance chain refines the Any-typed `other` for TorchScript;
        # every branch performs the same truncating division
        if not torch.jit.is_scripting():
            return input.div(other, rounding_mode="trunc")
        elif torch.jit.isinstance(other, float):
            return input.div(other, rounding_mode="trunc")
        elif torch.jit.isinstance(other, int):
            return input.div(other, rounding_mode="trunc")
        elif torch.jit.isinstance(other, torch.Tensor):
            return input.div(other, rounding_mode="trunc")
        else:
            assert False

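    # A usage sketch (illustrative): truncation rounds towards zero, unlike floor
    #
    #   trunc_divide(torch.tensor([-3.0]), 2)               # -> tensor([-1.])
    #   torch.tensor([-3.0]).div(2, rounding_mode="floor")  # -> tensor([-2.])
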
    def meshgrid(
        a: torch.Tensor, b: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pin the "ij" indexing explicitly to match pre-1.10 behaviour
        x = torch.meshgrid(a, b, indexing="ij")
        assert len(x) == 2
        return x[0], x[1]
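    # A usage sketch (illustrative):
    #
    #   rows, cols = meshgrid(torch.arange(2), torch.arange(3))
    #   # rows and cols both have shape (2, 3), following "ij" indexing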