Skip to content

Commit 66633c9

Browse files
authored
Compatibility fixes + setup changes (#58)
* Appveyor py310 and skip middling * Try this * Please * Mucking around with tox and versions * blech * Now? * Why do you hate me? * Is it you? * Hmmm * Mebbeh? * Weirdness * Do it mah dude * Blerg? * More blerg * Please * Hmm? * Getting minimum pytorch version working. * Updated CHANGELOG + formatting
1 parent 12eff87 commit 66633c9

15 files changed

+100
-215
lines changed

.appveyor.yml

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
version: build.{build}.branch.{branch}
2+
# FIXME(sdrobert): tox-ltt doesn't work on Windows as of yet, causing a
3+
# virtualenv error. Until it's fixed we'll use the Ubuntu image.
4+
image: Ubuntu
5+
6+
environment:
7+
matrix:
8+
- TOXENV: py36-earliest
9+
PYTHON: "3.6"
10+
- TOXENV: py37
11+
PYTHON: "3.7"
12+
- TOXENV: py38
13+
PYTHON: "3.8"
14+
- TOXENV: py39
15+
PYTHON: "3.9"
16+
# - TOXENV: py310 # wheel not available yet
17+
# PYTHON: "3.10"
18+
19+
stack: python %PYTHON%
20+
21+
branches:
22+
except:
23+
- /docs/
24+
25+
for:
26+
-
27+
matrix:
28+
only:
29+
- PYTHON: "3.7"
30+
- PYTHON: "3.8"
31+
skip_non_tags: true
32+
33+
build: off
34+
35+
install:
36+
- python3 -m pip install -U pip virtualenv setuptools wheel six
37+
- python3 -m pip install -U tox tox-ltt
38+
39+
test_script:
40+
- python3 -m tox

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,4 @@ venv.bak/
106106
.ftpignore
107107
.ftpconfig
108108
.vscode
109-
version.py
109+
_version.py

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
## HEAD
44

5+
- Removed `setup.py`.
6+
- Deleted conda recipe in prep for [conda-forge](https://conda-forge.org/).
7+
- Compatibility/determinism fixes for PyTorch 1.5.1.
8+
- Bumped minimum PyTorch version to 1.5.1. Actually testing this minimum!
9+
- Renamed `version.py` to `_version.py`.
510
- A number of modifications and additions related to decoding and language
611
models, including:
712
- `beam_search_advance` has been simplified, with much of the end-of-sequence

appveyor.yml

-28
This file was deleted.

pyproject.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
[build-system]
2-
requires = ["setuptools>=42", "wheel", "setuptools_scm[toml]>=3.4"]
2+
requires = ["setuptools>=45", "wheel", "setuptools_scm>=6.2"]
33
build-backend = "setuptools.build_meta"
4+
5+
[tool.setuptools_scm]
6+
write_to = "src/pydrobert/torch/_version.py"

recipe/meta.yaml

-80
This file was deleted.

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ package_dir =
2323
python_requires = >= 3.6
2424
install_requires =
2525
numpy
26-
torch>=1.0.1
26+
torch>=1.5.1
2727
param
2828

2929
[options.entry_points]

setup.py

-6
This file was deleted.

src/pydrobert/torch/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
__copyright__ = "Copyright 2021 Sean Robertson"
1919

2020
try:
21-
from .version import version as __version__ # type: ignore
21+
from ._version import version as __version__ # type: ignore
2222
except ImportError:
2323
__version__ = "inplace"
2424

src/pydrobert/torch/estimators.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,6 @@
2828

2929
import torch
3030

31-
try:
32-
torch_bool = torch.bool
33-
except AttributeError:
34-
torch_bool = torch.uint8
35-
3631
__all__ = [
3732
"to_z",
3833
"to_b",
@@ -425,11 +420,10 @@ def _to_z_tilde(logits, b, dist):
425420
-torch.log(-log_v / theta - log_v.gather(-1, b[..., None])),
426421
)
427422
elif dist in ONEHOT_SYNONYMS:
428-
b = b.byte()
429423
theta = torch.softmax(logits, dim=-1)
430424
log_v = v.log()
431425
z_tilde = torch.where(
432-
b,
426+
b.bool(),
433427
-torch.log(-log_v),
434428
-torch.log(-log_v / theta - log_v.gather(-1, b.argmax(-1, keepdim=True))),
435429
)

src/pydrobert/torch/layers.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1194,7 +1194,7 @@ def forward(
11941194
f"Expected dim 2 of logits to be {self.lm.vocab_size + 1}, got {Vp1}"
11951195
)
11961196
if lens is None:
1197-
lens = torch.full((N,), T, device=logits.device)
1197+
lens = torch.full((N,), T, device=logits.device, dtype=torch.long)
11981198
len_min = len_max = T
11991199
elif lens.dim() != 1:
12001200
raise RuntimeError("lens must be 1 dimensional")
@@ -1211,7 +1211,9 @@ def forward(
12111211
y_prev_lens = y_prev_last = torch.zeros(
12121212
(N, 1), dtype=torch.long, device=logits.device
12131213
)
1214-
prev_is_prefix = torch.full((N, 1, 1), True, device=logits.device)
1214+
prev_is_prefix = torch.full(
1215+
(N, 1, 1), True, device=logits.device, dtype=torch.bool
1216+
)
12151217
if self.lm is not None:
12161218
prev = self.lm.update_input(prev, y_prev)
12171219
prev_width = 1

src/pydrobert/torch/util.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def beam_search_advance(
249249
K = min(width, Kp * V)
250250
cand_log_probs = (log_probs_prev.unsqueeze(2) + log_probs_t).flatten(1)
251251
log_probs_next, next_ind = cand_log_probs.topk(K, 1)
252-
next_src = torch.div(next_ind, V, rounding_mode="trunc")
252+
next_src = next_ind.floor_divide(V)
253253
next_token = (next_ind % V).unsqueeze(0) # (1, N, K)
254254

255255
if tm1:
@@ -578,7 +578,9 @@ def ctc_prefix_search_advance(
578578
del tot_probs_cand
579579

580580
next_is_nonext = next_ind >= (Kp * V)
581-
next_src = torch.where(next_is_nonext, next_ind - (Kp * V), next_ind // V)
581+
next_src = torch.where(
582+
next_is_nonext, next_ind - (Kp * V), next_ind.floor_divide(V)
583+
)
582584
next_ext = next_ind % V
583585

584586
y_next_prefix_lens = y_prev_lens.gather(1, next_src) # (N, K)
@@ -1967,7 +1969,7 @@ def pad_variable(
19671969
arange_ = torch.arange(Tp, device=x.device)
19681970
left_mask = (pad[0].unsqueeze(1) > arange_).unsqueeze(2).expand(N, Tp, F)
19691971
if mode == "constant":
1970-
buff = torch.tensor(value, device=x.device, dtype=x.dtype).view(1)
1972+
buff = torch.tensor(value, device=x.device).to(x.dtype).view(1)
19711973
left_pad = buff.expand(pad[0].sum() * F)
19721974
right_pad = buff.expand(pad[1].sum() * F)
19731975
elif mode == "reflect":
@@ -2172,8 +2174,11 @@ def _string_matching(
21722174
assert not exclude_last or (return_mask or return_prf_dsts)
21732175
if ref.dim() != 2 or hyp.dim() != 2:
21742176
raise RuntimeError("ref and hyp must be 2 dimensional")
2177+
mult = 1.0
21752178
if ins_cost == del_cost == sub_cost > 0.0:
21762179
# results are equivalent and faster to return
2180+
if not return_mistakes:
2181+
mult = ins_cost
21772182
ins_cost = del_cost = sub_cost = 1.0
21782183
return_mistakes = False
21792184
elif return_mistakes and warn:
@@ -2220,8 +2225,12 @@ def _string_matching(
22202225
hyp_lens = hyp_lens - hyp_eq_mask.to(hyp_lens.dtype)
22212226
del ref_eq_mask, hyp_eq_mask
22222227
else:
2223-
ref_lens = torch.full((batch_size,), max_ref_steps, device=ref.device)
2224-
hyp_lens = torch.full((batch_size,), max_hyp_steps, device=ref.device)
2228+
ref_lens = torch.full(
2229+
(batch_size,), max_ref_steps, device=ref.device, dtype=torch.long
2230+
)
2231+
hyp_lens = torch.full(
2232+
(batch_size,), max_hyp_steps, device=ref.device, dtype=torch.long
2233+
)
22252234
ins_cost = torch.tensor(float(ins_cost), device=device)
22262235
del_cost = torch.tensor(float(del_cost), device=device)
22272236
sub_cost = torch.tensor(float(sub_cost), device=device)
@@ -2345,6 +2354,7 @@ def _string_matching(
23452354
)
23462355
return mask
23472356
elif return_prf_dsts:
2357+
prefix_ers = prefix_ers * mult
23482358
if norm:
23492359
prefix_ers = prefix_ers / ref_lens.to(row.dtype)
23502360
zero_mask = ref_lens.eq(0).unsqueeze(0)
@@ -2381,6 +2391,7 @@ def _string_matching(
23812391
er = mistakes.gather(0, ref_lens.unsqueeze(0)).squeeze(0)
23822392
else:
23832393
er = row.gather(0, ref_lens.unsqueeze(0)).squeeze(0)
2394+
er = er * mult
23842395
if norm:
23852396
er = er / ref_lens.to(er.dtype)
23862397
zero_mask = ref_lens.eq(0)

tests/test_layers.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -767,9 +767,7 @@ def update_input(self, prev, hist):
767767
def calc_idx_log_probs(self, hist, prev, idx):
768768
idx_zero = idx == 0
769769
if idx_zero.all():
770-
x = torch.arange(hist.size(0), device=hist.device).clamp(
771-
max=self.vocab_size
772-
)
770+
x = torch.arange(hist.size(0), device=hist.device)
773771
elif not idx.dim():
774772
x = hist[idx - 1]
775773
else:
@@ -785,8 +783,8 @@ def calc_idx_log_probs(self, hist, prev, idx):
785783
{"hidden_state": h_1, "cell_state": c_1},
786784
)
787785

788-
T, N, V, K = 64, 16, 32, 8
789-
assert K <= V and N <= V
786+
T, N, V, K = 64, 16, 128, 8
787+
assert K <= V and N * K <= V
790788
lm = RNNLM(V)
791789
search = layers.BeamSearch(lm, K, eos=0, max_iters=T).to(device)
792790
y_prev = torch.arange(N, device=device)

0 commit comments

Comments
 (0)