Skip to content

Commit e3687f8

Browse files
Remove TimeVaryingDiscreteDistribution, extend IndexDistribution functionality
Co-authored-by: alanlujan91 <[email protected]>
1 parent a98e751 commit e3687f8

File tree

6 files changed

+54
-97
lines changed

6 files changed

+54
-97
lines changed

HARK/Calibration/Income/IncomeProcesses.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
DiscreteDistributionLabeled,
1212
IndexDistribution,
1313
MeanOneLogNormal,
14-
TimeVaryingDiscreteDistribution,
1514
Lognormal,
1615
Uniform,
1716
)
@@ -813,14 +812,14 @@ def get_PermShkDstn_from_IncShkDstn(IncShkDstn, RNG):
813812
PermShkDstn = [
814813
this.make_univariate(0, seed=RNG.integers(0, 2**31 - 1)) for this in IncShkDstn
815814
]
816-
return TimeVaryingDiscreteDistribution(PermShkDstn, seed=RNG.integers(0, 2**31 - 1))
815+
return IndexDistribution(distributions=PermShkDstn, seed=RNG.integers(0, 2**31 - 1))
817816

818817

819818
def get_TranShkDstn_from_IncShkDstn(IncShkDstn, RNG):
820819
TranShkDstn = [
821820
this.make_univariate(1, seed=RNG.integers(0, 2**31 - 1)) for this in IncShkDstn
822821
]
823-
return TimeVaryingDiscreteDistribution(TranShkDstn, seed=RNG.integers(0, 2**31 - 1))
822+
return IndexDistribution(distributions=TranShkDstn, seed=RNG.integers(0, 2**31 - 1))
824823

825824

826825
def get_PermShkDstn_from_IncShkDstn_markov(IncShkDstn, RNG):

HARK/core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from HARK.distributions import (
2626
Distribution,
2727
IndexDistribution,
28-
TimeVaryingDiscreteDistribution,
2928
combine_indep_dstns,
3029
)
3130
from HARK.parallel import multi_thread_commands, multi_thread_commands_fake
@@ -1051,7 +1050,7 @@ def check_elements_of_time_vary_are_lists(self):
10511050
continue
10521051
if not isinstance(
10531052
getattr(self, param),
1054-
(TimeVaryingDiscreteDistribution, IndexDistribution),
1053+
(IndexDistribution,),
10551054
):
10561055
assert type(getattr(self, param)) == list, (
10571056
param

HARK/distributions/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"DiscreteDistributionLabeled",
44
"Distribution",
55
"IndexDistribution",
6-
"TimeVaryingDiscreteDistribution",
76
"Lognormal",
87
"MeanOneLogNormal",
98
"Normal",
@@ -29,7 +28,6 @@
2928
Distribution,
3029
IndexDistribution,
3130
MarkovProcess,
32-
TimeVaryingDiscreteDistribution,
3331
)
3432
from HARK.distributions.continuous import (
3533
Lognormal,
@@ -50,7 +48,7 @@
5048
approx_lognormal_gauss_hermite,
5149
calc_expectation,
5250
calc_lognormal_style_pars_from_normal_pars,
53-
calc_normal_style_pars_from_lognormal_pars,
51+
calc_lognormal_style_pars_from_lognormal_pars,
5452
combine_indep_dstns,
5553
distr_of_function,
5654
expected,

HARK/distributions/base.py

Lines changed: 45 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,8 @@ class IndexDistribution(Distribution):
219219
class (such as Bernoulli, LogNormal, etc.) with information
220220
about the conditions on the parameters of the distribution.
221221
222-
For example, an IndexDistribution can be defined as
223-
a Bernoulli distribution whose parameter p is a function of
224-
a different input parameter.
222+
It can also wrap a list of pre-discretized distributions (previously
223+
provided by TimeVaryingDiscreteDistribution) and provide the same API.
225224
226225
Parameters
227226
----------
@@ -235,14 +234,17 @@ class (such as Bernoulli, LogNormal, etc.) with information
235234
Keys should match the arguments to the engine class
236235
constructor.
237236
237+
distributions: [DiscreteDistribution]
238+
Optional. A list of discrete distributions to wrap directly.
239+
238240
seed : int
239241
Seed for random number generator.
240242
"""
241243

242244
conditional = None
243245
engine = None
244246

245-
def __init__(self, engine, conditional, RNG=None, seed=0):
247+
def __init__(self, engine=None, conditional=None, distributions=None, RNG=None, seed=0):
246248
if RNG is None:
247249
# Set up the RNG
248250
super().__init__(seed)
@@ -255,11 +257,24 @@ def __init__(self, engine, conditional, RNG=None, seed=0):
255257
# and create a new one.
256258
self.seed = seed
257259

258-
self.conditional = conditional
260+
# Mode 1: wrapping a list of discrete distributions
261+
if distributions is not None:
262+
self.distributions = distributions
263+
self.engine = None
264+
self.conditional = None
265+
self.dstns = []
266+
return
267+
268+
# Mode 2: engine + conditional parameters (original IndexDistribution)
269+
self.conditional = conditional if conditional is not None else {}
259270
self.engine = engine
260271

261272
self.dstns = []
262273

274+
# If no engine/conditional were provided, remain empty (should not happen in normal use)
275+
if self.engine is None and not self.conditional:
276+
return
277+
263278
# Test one item to determine case handling
264279
item0 = list(self.conditional.values())[0]
265280

@@ -273,7 +288,7 @@ def __init__(self, engine, conditional, RNG=None, seed=0):
273288

274289
elif type(item0) is float:
275290
self.dstns = [
276-
self.engine(seed=self._rng.integers(0, 2**31 - 1), **conditional)
291+
self.engine(seed=self._rng.integers(0, 2**31 - 1), **self.conditional)
277292
]
278293

279294
else:
@@ -284,6 +299,9 @@ def __init__(self, engine, conditional, RNG=None, seed=0):
284299
)
285300

286301
def __getitem__(self, y):
302+
# Prefer discrete list mode if present
303+
if hasattr(self, "distributions") and self.distributions:
304+
return self.distributions[y]
287305
return self.dstns[y]
288306

289307
def discretize(self, N, **kwds):
@@ -302,16 +320,16 @@ def discretize(self, N, **kwds):
302320
303321
Returns:
304322
------------
305-
dists : [DiscreteDistribution]
306-
A list of DiscreteDistributions that are the
307-
approximation of engine distribution under each condition.
308-
309-
TODO: It would be better if there were a conditional discrete
310-
distribution representation. But that integrates with the
311-
solution code. This implementation will return the list of
312-
distributions representations expected by the solution code.
323+
dists : [DiscreteDistribution] or IndexDistribution
324+
If parameterization is constant, returns a single DiscreteDistribution.
325+
If parameterization varies with index, returns an IndexDistribution in
326+
discrete-list mode, wrapping the corresponding discrete distributions.
313327
"""
314328

329+
# If already in discrete list mode, return self (already discretized)
330+
if hasattr(self, "distributions") and self.distributions:
331+
return self
332+
315333
# test one item to determine case handling
316334
item0 = list(self.conditional.values())[0]
317335

@@ -320,8 +338,10 @@ def discretize(self, N, **kwds):
320338
return self.dstns[0].discretize(N, **kwds)
321339

322340
if type(item0) is list:
323-
return TimeVaryingDiscreteDistribution(
324-
[self[i].discretize(N, **kwds) for i, _ in enumerate(item0)]
341+
# Return an IndexDistribution wrapping a list of discrete distributions
342+
return IndexDistribution(
343+
distributions=[self[i].discretize(N, **kwds) for i, _ in enumerate(item0)],
344+
seed=self.seed,
325345
)
326346

327347
def draw(self, condition):
@@ -345,6 +365,15 @@ def draw(self, condition):
345365
# are of the same type.
346366
# this matches the HARK 'time-varying' model architecture.
347367

368+
# If wrapping discrete distributions, draw from those
369+
if hasattr(self, "distributions") and self.distributions:
370+
draws = np.zeros(condition.size)
371+
for c in np.unique(condition):
372+
these = c == condition
373+
N = np.sum(these)
374+
draws[these] = self.distributions[c].draw(N)
375+
return draws
376+
348377
# test one item to determine case handling
349378
item0 = list(self.conditional.values())[0]
350379

@@ -367,70 +396,6 @@ def draw(self, condition):
367396
these = c == condition
368397
N = np.sum(these)
369398

370-
cond = {key: val[c] for (key, val) in self.conditional.items()}
371399
draws[these] = self[c].draw(N)
372400

373401
return draws
374-
375-
376-
class TimeVaryingDiscreteDistribution(Distribution):
377-
"""
378-
This class provides a way to define a discrete distribution that
379-
is conditional on an index.
380-
381-
Wraps a list of discrete distributions.
382-
383-
Parameters
384-
----------
385-
386-
distributions : [DiscreteDistribution]
387-
A list of discrete distributions
388-
389-
seed : int
390-
Seed for random number generator.
391-
"""
392-
393-
distributions = []
394-
395-
def __init__(self, distributions, seed=0):
396-
# Set up the RNG
397-
super().__init__(seed)
398-
399-
self.distributions = distributions
400-
401-
def __getitem__(self, y):
402-
return self.distributions[y]
403-
404-
def draw(self, condition):
405-
"""
406-
Generate arrays of draws.
407-
The input is an array containing the conditions.
408-
The output is an array of the same length (axis 1 dimension)
409-
as the conditions containing random draws of the conditional
410-
distribution.
411-
412-
Parameters
413-
----------
414-
condition : np.array
415-
The input conditions to the distribution.
416-
417-
Returns:
418-
------------
419-
draws : np.array
420-
"""
421-
# for now, assume that all the conditionals
422-
# are of the same type.
423-
# this matches the HARK 'time-varying' model architecture.
424-
425-
# conditions are indices into list
426-
# somewhat convoluted sampling strategy retained
427-
# for test backwards compatibility
428-
draws = np.zeros(condition.size)
429-
430-
for c in np.unique(condition):
431-
these = c == condition
432-
N = np.sum(these)
433-
434-
draws[these] = self.distributions[c].draw(N)
435-
436-
return draws

HARK/distributions/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
from scipy import stats
66

7-
from HARK.distributions.base import TimeVaryingDiscreteDistribution
7+
from HARK.distributions.base import IndexDistribution
88
from HARK.distributions.discrete import (
99
DiscreteDistribution,
1010
DiscreteDistributionLabeled,
@@ -265,10 +265,10 @@ def add_discrete_outcome_constant_mean(distribution, x, p, sort=False):
265265
Probability associated with each point in array of discrete
266266
points for discrete probability mass function.
267267
"""
268-
if type(distribution) == TimeVaryingDiscreteDistribution:
268+
if isinstance(distribution, IndexDistribution) and hasattr(distribution, "distributions") and distribution.distributions:
269269
# apply recursively on all the internal distributions
270-
return TimeVaryingDiscreteDistribution(
271-
[
270+
return IndexDistribution(
271+
distributions=[
272272
add_discrete_outcome_constant_mean(d, x, p)
273273
for d in distribution.distributions
274274
],

HARK/simulation/monte_carlo.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from HARK.distributions import (
1111
Distribution,
1212
IndexDistribution,
13-
TimeVaryingDiscreteDistribution,
1413
)
1514
from HARK.model import Aggregate
1615
from HARK.model import DBlock
@@ -47,10 +46,7 @@ def draw_shocks(shocks: Mapping[str, Distribution], conditions: Sequence[int]):
4746
draws[shock_var] = np.ones(len(conditions)) * shock
4847
elif isinstance(shock, Aggregate):
4948
draws[shock_var] = shock.dist.draw(1)[0]
50-
elif isinstance(shock, IndexDistribution) or isinstance(
51-
shock, TimeVaryingDiscreteDistribution
52-
):
53-
## TODO his type test is awkward. They should share a superclass.
49+
elif isinstance(shock, IndexDistribution):
5450
draws[shock_var] = shock.draw(conditions)
5551
else:
5652
draws[shock_var] = shock.draw(len(conditions))

0 commit comments

Comments
 (0)