Skip to content

Commit 6665dc3

Browse files
Fix tt-prefixed utility names
1 parent a735f00 commit 6665dc3

File tree

2 files changed

+22
-25
lines changed

2 files changed

+22
-25
lines changed

pymc3_hmm/utils.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import matplotlib.pyplot as plt
55
import numpy as np
66
import pandas as pd
7+
import scipy.special as sp
78
from matplotlib import cm
89
from matplotlib.axes import Axes
910
from matplotlib.colors import Colormap
10-
from scipy.special import logsumexp
1111

1212
vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")
1313

@@ -76,7 +76,7 @@ def compute_trans_freqs(states, N_states, counts_only=False):
7676
return res
7777

7878

79-
def tt_logsumexp(x, axis=None, keepdims=False):
79+
def logsumexp(x, axis=None, keepdims=False):
8080
"""Construct a Theano graph for a log-sum-exp calculation."""
8181
x_max_ = at.max(x, axis=axis, keepdims=True)
8282

@@ -103,7 +103,7 @@ def tt_logsumexp(x, axis=None, keepdims=False):
103103
return res + x_max_
104104

105105

106-
def tt_logdotexp(A, b):
106+
def logdotexp(A, b):
107107
"""Construct a Theano graph for a numerically stable log-scale dot product.
108108
109109
The result is more or less equivalent to `tt.log(tt.exp(A).dot(tt.exp(b)))`
@@ -120,11 +120,11 @@ def tt_logdotexp(A, b):
120120
sqz = True
121121

122122
b_bcast = b.dimshuffle(shape_b)
123-
res = tt_logsumexp(A_bcast + b_bcast, axis=1)
123+
res = logsumexp(A_bcast + b_bcast, axis=1)
124124
return res.squeeze() if sqz else res
125125

126126

127-
def logdotexp(A, b):
127+
def np_logdotexp(A, b):
128128
"""Compute a numerically stable log-scale dot product of NumPy values.
129129
130130
The result is more or less equivalent to `np.log(np.exp(A).dot(np.exp(b)))`
@@ -138,7 +138,7 @@ def logdotexp(A, b):
138138

139139
A_bcast = np.expand_dims(A, -1)
140140

141-
res = logsumexp(A_bcast + b_bcast, axis=1)
141+
res = sp.logsumexp(A_bcast + b_bcast, axis=1)
142142
return res.squeeze() if sqz else res
143143

144144

@@ -184,10 +184,10 @@ def multilogit_inv(ys):
184184
"""
185185
if isinstance(ys, np.ndarray):
186186
lib = np
187-
lib_logsumexp = logsumexp
187+
lib_logsumexp = sp.logsumexp
188188
else:
189189
lib = at
190-
lib_logsumexp = tt_logsumexp
190+
lib_logsumexp = logsumexp
191191

192192
# exp_ys = lib.exp(ys)
193193
# res = lib.concatenate([exp_ys, lib.ones(tuple(ys.shape)[:-1] + (1,))], axis=-1)

tests/test_utils.py

+14-17
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import aesara
21
import aesara.tensor as at
32
import numpy as np
43
import pytest
@@ -7,9 +6,9 @@
76
from pymc3_hmm.utils import (
87
compute_trans_freqs,
98
logdotexp,
9+
logsumexp,
1010
multilogit_inv,
11-
tt_logdotexp,
12-
tt_logsumexp,
11+
np_logdotexp,
1312
)
1413

1514

@@ -35,62 +34,60 @@ def test_compute_trans_freqs():
3534
)
3635
def test_logsumexp(test_input):
3736
np_res = sp.special.logsumexp(test_input)
38-
tt_res = tt_logsumexp(at.as_tensor_variable(test_input)).eval()
39-
assert np.array_equal(np_res, tt_res)
37+
at_res = logsumexp(at.as_tensor_variable(test_input)).eval()
38+
assert np.array_equal(np_res, at_res)
4039

4140

42-
def test_logdotexp():
41+
def test_np_logdotexp():
4342
A = np.c_[[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
4443
b = np.c_[[0.1], [0.2], [30.0]].T
4544

46-
test_res = logdotexp(np.log(A), np.log(b))
45+
test_res = np_logdotexp(np.log(A), np.log(b))
4746
assert test_res.shape == (2, 1)
4847
assert np.allclose(A.dot(b), np.exp(test_res))
4948

5049
b = np.r_[0.1, 0.2, 30.0]
51-
test_res = logdotexp(np.log(A), np.log(b))
50+
test_res = np_logdotexp(np.log(A), np.log(b))
5251
assert test_res.shape == (2,)
5352
assert np.allclose(A.dot(b), np.exp(test_res))
5453

5554
A = np.c_[[1.0, 2.0], [10.0, 20.0]]
5655
b = np.c_[[0.1], [0.2]].T
57-
test_res = logdotexp(np.log(A), np.log(b))
56+
test_res = np_logdotexp(np.log(A), np.log(b))
5857
assert test_res.shape == (2, 1)
5958
assert np.allclose(A.dot(b), np.exp(test_res))
6059

6160
b = np.r_[0.1, 0.2]
62-
test_res = logdotexp(np.log(A), np.log(b))
61+
test_res = np_logdotexp(np.log(A), np.log(b))
6362
assert test_res.shape == (2,)
6463
assert np.allclose(A.dot(b), np.exp(test_res))
6564

6665

67-
def test_tt_logdotexp():
66+
def test_at_logdotexp():
6867

6968
np.seterr(over="ignore", under="ignore")
7069

71-
aesara.config.compute_test_value = "warn"
72-
7370
A = np.c_[[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]
7471
b = np.c_[[0.1], [0.2], [30.0]].T
7572
A_tt = at.as_tensor_variable(A)
7673
b_tt = at.as_tensor_variable(b)
77-
test_res = tt_logdotexp(at.log(A_tt), at.log(b_tt)).eval()
74+
test_res = logdotexp(at.log(A_tt), at.log(b_tt)).eval()
7875
assert test_res.shape == (2, 1)
7976
assert np.allclose(A.dot(b), np.exp(test_res))
8077

8178
b = np.r_[0.1, 0.2, 30.0]
82-
test_res = tt_logdotexp(at.log(A), at.log(b)).eval()
79+
test_res = logdotexp(at.log(A), at.log(b)).eval()
8380
assert test_res.shape == (2,)
8481
assert np.allclose(A.dot(b), np.exp(test_res))
8582

8683
A = np.c_[[1.0, 2.0], [10.0, 20.0]]
8784
b = np.c_[[0.1], [0.2]].T
88-
test_res = tt_logdotexp(at.log(A), at.log(b)).eval()
85+
test_res = logdotexp(at.log(A), at.log(b)).eval()
8986
assert test_res.shape == (2, 1)
9087
assert np.allclose(A.dot(b), np.exp(test_res))
9188

9289
b = np.r_[0.1, 0.2]
93-
test_res = tt_logdotexp(at.log(A), at.log(b)).eval()
90+
test_res = logdotexp(at.log(A), at.log(b)).eval()
9491
assert test_res.shape == (2,)
9592
assert np.allclose(A.dot(b), np.exp(test_res))
9693

0 commit comments

Comments
 (0)