@@ -219,9 +219,8 @@ class IndexDistribution(Distribution):
219
219
class (such as Bernoulli, LogNormal, etc.) with information
220
220
about the conditions on the parameters of the distribution.
221
221
222
- For example, an IndexDistribution can be defined as
223
- a Bernoulli distribution whose parameter p is a function of
224
- a different input parameter.
222
+ It can also wrap a list of pre-discretized distributions (previously
223
+ provided by TimeVaryingDiscreteDistribution) and provide the same API.
225
224
226
225
Parameters
227
226
----------
@@ -235,14 +234,17 @@ class (such as Bernoulli, LogNormal, etc.) with information
235
234
Keys should match the arguments to the engine class
236
235
constructor.
237
236
237
+ distributions: [DiscreteDistribution]
238
+ Optional. A list of discrete distributions to wrap directly.
239
+
238
240
seed : int
239
241
Seed for random number generator.
240
242
"""
241
243
242
244
conditional = None
243
245
engine = None
244
246
245
- def __init__ (self , engine , conditional , RNG = None , seed = 0 ):
247
+ def __init__ (self , engine = None , conditional = None , distributions = None , RNG = None , seed = 0 ):
246
248
if RNG is None :
247
249
# Set up the RNG
248
250
super ().__init__ (seed )
@@ -255,11 +257,24 @@ def __init__(self, engine, conditional, RNG=None, seed=0):
255
257
# and create a new one.
256
258
self .seed = seed
257
259
258
- self .conditional = conditional
260
+ # Mode 1: wrapping a list of discrete distributions
261
+ if distributions is not None :
262
+ self .distributions = distributions
263
+ self .engine = None
264
+ self .conditional = None
265
+ self .dstns = []
266
+ return
267
+
268
+ # Mode 2: engine + conditional parameters (original IndexDistribution)
269
+ self .conditional = conditional if conditional is not None else {}
259
270
self .engine = engine
260
271
261
272
self .dstns = []
262
273
274
+ # If no engine/conditional were provided, remain empty (should not happen in normal use)
275
+ if self .engine is None and not self .conditional :
276
+ return
277
+
263
278
# Test one item to determine case handling
264
279
item0 = list (self .conditional .values ())[0 ]
265
280
@@ -273,7 +288,7 @@ def __init__(self, engine, conditional, RNG=None, seed=0):
273
288
274
289
elif type (item0 ) is float :
275
290
self .dstns = [
276
- self .engine (seed = self ._rng .integers (0 , 2 ** 31 - 1 ), ** conditional )
291
+ self .engine (seed = self ._rng .integers (0 , 2 ** 31 - 1 ), ** self . conditional )
277
292
]
278
293
279
294
else :
@@ -284,6 +299,9 @@ def __init__(self, engine, conditional, RNG=None, seed=0):
284
299
)
285
300
286
301
def __getitem__ (self , y ):
302
+ # Prefer discrete list mode if present
303
+ if hasattr (self , "distributions" ) and self .distributions :
304
+ return self .distributions [y ]
287
305
return self .dstns [y ]
288
306
289
307
def discretize (self , N , ** kwds ):
@@ -302,16 +320,16 @@ def discretize(self, N, **kwds):
302
320
303
321
Returns:
304
322
------------
305
- dists : [DiscreteDistribution]
306
- A list of DiscreteDistributions that are the
307
- approximation of engine distribution under each condition.
308
-
309
- TODO: It would be better if there were a conditional discrete
310
- distribution representation. But that integrates with the
311
- solution code. This implementation will return the list of
312
- distributions representations expected by the solution code.
323
+ dists : [DiscreteDistribution] or IndexDistribution
324
+ If parameterization is constant, returns a single DiscreteDistribution.
325
+ If parameterization varies with index, returns an IndexDistribution in
326
+ discrete-list mode, wrapping the corresponding discrete distributions.
313
327
"""
314
328
329
+ # If already in discrete list mode, return self (already discretized)
330
+ if hasattr (self , "distributions" ) and self .distributions :
331
+ return self
332
+
315
333
# test one item to determine case handling
316
334
item0 = list (self .conditional .values ())[0 ]
317
335
@@ -320,8 +338,10 @@ def discretize(self, N, **kwds):
320
338
return self .dstns [0 ].discretize (N , ** kwds )
321
339
322
340
if type (item0 ) is list :
323
- return TimeVaryingDiscreteDistribution (
324
- [self [i ].discretize (N , ** kwds ) for i , _ in enumerate (item0 )]
341
+ # Return an IndexDistribution wrapping a list of discrete distributions
342
+ return IndexDistribution (
343
+ distributions = [self [i ].discretize (N , ** kwds ) for i , _ in enumerate (item0 )],
344
+ seed = self .seed ,
325
345
)
326
346
327
347
def draw (self , condition ):
@@ -345,6 +365,15 @@ def draw(self, condition):
345
365
# are of the same type.
346
366
# this matches the HARK 'time-varying' model architecture.
347
367
368
+ # If wrapping discrete distributions, draw from those
369
+ if hasattr (self , "distributions" ) and self .distributions :
370
+ draws = np .zeros (condition .size )
371
+ for c in np .unique (condition ):
372
+ these = c == condition
373
+ N = np .sum (these )
374
+ draws [these ] = self .distributions [c ].draw (N )
375
+ return draws
376
+
348
377
# test one item to determine case handling
349
378
item0 = list (self .conditional .values ())[0 ]
350
379
@@ -367,70 +396,6 @@ def draw(self, condition):
367
396
these = c == condition
368
397
N = np .sum (these )
369
398
370
- cond = {key : val [c ] for (key , val ) in self .conditional .items ()}
371
399
draws [these ] = self [c ].draw (N )
372
400
373
401
return draws
374
-
375
-
376
- class TimeVaryingDiscreteDistribution (Distribution ):
377
- """
378
- This class provides a way to define a discrete distribution that
379
- is conditional on an index.
380
-
381
- Wraps a list of discrete distributions.
382
-
383
- Parameters
384
- ----------
385
-
386
- distributions : [DiscreteDistribution]
387
- A list of discrete distributions
388
-
389
- seed : int
390
- Seed for random number generator.
391
- """
392
-
393
- distributions = []
394
-
395
- def __init__ (self , distributions , seed = 0 ):
396
- # Set up the RNG
397
- super ().__init__ (seed )
398
-
399
- self .distributions = distributions
400
-
401
- def __getitem__ (self , y ):
402
- return self .distributions [y ]
403
-
404
- def draw (self , condition ):
405
- """
406
- Generate arrays of draws.
407
- The input is an array containing the conditions.
408
- The output is an array of the same length (axis 1 dimension)
409
- as the conditions containing random draws of the conditional
410
- distribution.
411
-
412
- Parameters
413
- ----------
414
- condition : np.array
415
- The input conditions to the distribution.
416
-
417
- Returns:
418
- ------------
419
- draws : np.array
420
- """
421
- # for now, assume that all the conditionals
422
- # are of the same type.
423
- # this matches the HARK 'time-varying' model architecture.
424
-
425
- # conditions are indices into list
426
- # somewhat convoluted sampling strategy retained
427
- # for test backwards compatibility
428
- draws = np .zeros (condition .size )
429
-
430
- for c in np .unique (condition ):
431
- these = c == condition
432
- N = np .sum (these )
433
-
434
- draws [these ] = self .distributions [c ].draw (N )
435
-
436
- return draws
0 commit comments