16
16
from functools import singledispatch
17
17
18
18
import numpy as np
19
+ import pytensor
19
20
import pytensor .tensor as pt
20
21
21
22
@@ -179,6 +180,33 @@ def __init__(self, n):
179
180
"""
180
181
self .n = n
181
182
183
+ def step (self , i , counter , L , y ):
184
+ y_star = y [counter : counter + i ]
185
+ dsy = y_star .dot (y_star )
186
+ alpha_r = 1 / (dsy + 1 )
187
+ gamma = pt .sqrt (dsy + 2 ) * alpha_r
188
+
189
+ x = pt .join (0 , gamma * y_star , pt .atleast_1d (alpha_r ))
190
+ next_L = L [i , : i + 1 ].set (x )
191
+ log_det = pt .log (2 ) + 0.5 * (i - 2 ) * pt .log (dsy + 2 ) - i * pt .log (1 + dsy )
192
+
193
+ return next_L , log_det
194
+
195
+ def _compute_L_and_logdet_scan (self , value , * inputs ):
196
+ L = pt .eye (self .n )
197
+ idxs = pt .arange (1 , self .n )
198
+ counters = pt .arange (0 , self .n ).cumsum ()
199
+
200
+ results , _ = pytensor .scan (
201
+ self .step , outputs_info = [L , None ], sequences = [idxs , counters ], non_sequences = [value ]
202
+ )
203
+
204
+ L_seq , log_det_seq = results
205
+ L = L_seq [- 1 ]
206
+ log_det = pt .sum (log_det_seq )
207
+
208
+ return L , log_det
209
+
182
210
def _compute_L_and_logdet (self , value , * inputs ):
183
211
n = self .n
184
212
counter = 0
@@ -201,7 +229,7 @@ def _compute_L_and_logdet(self, value, *inputs):
201
229
return L , log_det
202
230
203
231
def backward (self , value , * inputs ):
204
- L , _ = self ._compute_L_and_logdet (value , * inputs )
232
+ L , _ = self ._compute_L_and_logdet_scan (value , * inputs )
205
233
return L
206
234
207
235
def forward (self , value , * inputs ):
@@ -211,7 +239,7 @@ def forward(self, value, *inputs):
211
239
return pt .as_tensor_variable (np .random .normal (size = size ))
212
240
213
241
def log_jac_det (self , value , * inputs ):
214
- _ , log_det = self ._compute_L_and_logdet (value , * inputs )
242
+ _ , log_det = self ._compute_L_and_logdet_scan (value , * inputs )
215
243
return log_det
216
244
217
245
0 commit comments