Skip to content

Commit ddbc3fe

Browse files
just use scan bro
1 parent 3e5721c commit ddbc3fe

File tree

1 file changed

+30
-2
lines changed

1 file changed

+30
-2
lines changed

pymc/distributions/transforms.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from functools import singledispatch
1717

1818
import numpy as np
19+
import pytensor
1920
import pytensor.tensor as pt
2021

2122

@@ -179,6 +180,33 @@ def __init__(self, n):
179180
"""
180181
self.n = n
181182

183+
def step(self, i, counter, L, y):
184+
y_star = y[counter : counter + i]
185+
dsy = y_star.dot(y_star)
186+
alpha_r = 1 / (dsy + 1)
187+
gamma = pt.sqrt(dsy + 2) * alpha_r
188+
189+
x = pt.join(0, gamma * y_star, pt.atleast_1d(alpha_r))
190+
next_L = L[i, : i + 1].set(x)
191+
log_det = pt.log(2) + 0.5 * (i - 2) * pt.log(dsy + 2) - i * pt.log(1 + dsy)
192+
193+
return next_L, log_det
194+
195+
def _compute_L_and_logdet_scan(self, value, *inputs):
196+
L = pt.eye(self.n)
197+
idxs = pt.arange(1, self.n)
198+
counters = pt.arange(0, self.n).cumsum()
199+
200+
results, _ = pytensor.scan(
201+
self.step, outputs_info=[L, None], sequences=[idxs, counters], non_sequences=[value]
202+
)
203+
204+
L_seq, log_det_seq = results
205+
L = L_seq[-1]
206+
log_det = pt.sum(log_det_seq)
207+
208+
return L, log_det
209+
182210
def _compute_L_and_logdet(self, value, *inputs):
183211
n = self.n
184212
counter = 0
@@ -201,7 +229,7 @@ def _compute_L_and_logdet(self, value, *inputs):
201229
return L, log_det
202230

203231
def backward(self, value, *inputs):
204-
L, _ = self._compute_L_and_logdet(value, *inputs)
232+
L, _ = self._compute_L_and_logdet_scan(value, *inputs)
205233
return L
206234

207235
def forward(self, value, *inputs):
@@ -211,7 +239,7 @@ def forward(self, value, *inputs):
211239
return pt.as_tensor_variable(np.random.normal(size=size))
212240

213241
def log_jac_det(self, value, *inputs):
214-
_, log_det = self._compute_L_and_logdet(value, *inputs)
242+
_, log_det = self._compute_L_and_logdet_scan(value, *inputs)
215243
return log_det
216244

217245

0 commit comments

Comments
 (0)