diff --git a/examples/probability.dx b/examples/probability.dx new file mode 100644 index 000000000..b5433cb91 --- /dev/null +++ b/examples/probability.dx @@ -0,0 +1,517 @@ +' # Differential Probabilistic Inference + +' This notebook develops an unconventional approach to probabilistic +inference using Dex tables and auto-differentiation. It is based +loosely on the approach described in: + +' `A Differential Approach to Inference in Bayesian Networks, Adnan Darwiche (2003)` + +' This approach can be thought of as a probabilistic programming +language (PPL) in the sense that inference is separated from the +modeling language. However, it does not require interrupting +standard control flow. + +' ## Running Example 1: Coins and Dice + +' Let us start with a simple probabilistic modeling example to establish some notation. + + In this example we have a coin and two weighted dice. We first + flip the coin; if it is tails we roll dice 1, and if it is heads we + roll dice 2. + +Coin = Fin 2 +[tails, heads] = for i:Coin. i + +Dice = Range 1 7 +roll = \i . (i - 1)@ Dice + +None = Fin 1 +nil = 0@None + +coin : Coin =>Float = [0.2, 0.8] +dice_1 : Dice => Float = for i. 1.0 / 6.0 +dice_2 : Dice => Float = [0.5, 0.1, 0.1, 0.1, 0.1, 0.1] + + +' This defines a generative process over two random variables $\mathbf{X} = \{ A, B \} $, the coin flip and the dice roll respectively. We can write the process explicitly as, + +' + $$a \sim Pr(A)$$ + $$ b \sim Pr(B\ |\ A=a) $$ + +' ## Probability Combinators + +' A discrete probability distribution is a normalized table of probabilities. + +data Dist variables = AsDist (variables => Float) +def (??) (y:m) (AsDist x: Dist m) : Float = x.y + +' Distributions are easy to create. Here are a couple of simple ones. + +def normalize (x: m=>Float) : Dist m = AsDist for i. x.i / sum x +def uniform : Dist m = normalize for i. 1.0 +def delta (x:m) : Dist m = AsDist for i.
select ((ordinal x) == (ordinal i)) 1.0 0.0 + +instance Arbitrary (Dist m) + arb = \key. + a = arb key + normalize $ for i. abs a.i + +' And they are displayed by their support. + +def support (AsDist x: Dist m) : List (m & Float) = + concat $ for i. select (x.i > 0.0) (AsList 1 [(i, x.i)]) mempty + +instance [Show m] Show (Dist m) + show = \a. + (AsList _ out) = support a + concat $ for i. "Key: " <> (show $ fst out.i) <> " Prob: " <> (show $ snd out.i) <> "\n" + +instance Show (Range i j) + show = \a . show $ ordinal a + +show $ (delta (4@_)):Dist Dice + + +' Expectations can be taken over arbitrary tables. + +def expect [VSpace out] (AsDist x: Dist m) (y : m => out) : out = + sum for m'. x.m' .* y.m' + +' To represent conditional probabilities such as $ Pr(B \ |\ A)$ we define a type alias. + +def Pr (b:Type) (a:Type): Type = a => Dist b + +' With this machinery we can define distributions for the coin and the dice. + +p_A : Pr Coin None = [AsDist coin] +p_B_A : Pr Dice Coin = [AsDist dice_1, AsDist dice_2] + + +' ## Attempt 1: Observations and Marginalization + +' This allows us to compute the probability of any full observation + from our model. + $$Pr(A=a, B=b) = Pr(B=b\ | A=a) Pr(A=a) $$ + + +def p_AB (a:Coin) (b:Dice) : Float = + (a ?? p_A.nil) * + (b ?? p_B_A.a) + +p_AB heads (roll 6) + + +' However, this assumes that we have full observation + of our variables. What if the coin is latent? This requires a sum. + $$Pr(B=b) = \sum_a Pr(B=b\ | A=a) Pr(A=a) $$ + +def p_B (b:Dice) : Float = + sum for a. (a ?? p_A.nil) * + (b ?? p_B_A.a) + +p_B (roll 6) + + +' But now we have two separate functions for the same model! + This feels unnecessary and bug-prone. + +' ## Attempt 2: Indicator Variables and the Network Polynomial + +' In order to make things simpler we will introduce explicit *indicator variables* $\lambda$ + to the modeling language. + +def Var (a:Type) : Type = a => Float + +' These can either be observed or latent.
+ If a random variable is observed then we use an indicator. + The expectation over the indicator gives, + +' $$ p(A=a) = E_{a' \sim p(A)} \lambda_{a'} $$ + +' If it is latent the variable is one everywhere. + +def observed (x:a) : Var a = for i. select ((ordinal i) == (ordinal x)) 1.0 0.0 +def latent : Var a = one + +' The probability *chain rule* tells us that we can propagate conditioning. + $$\sum_{b} Pr(A,\ B = b) = Pr(A) \sum_b Pr(B=b\ | A)$$ + +' This implies that the expectation of these indicators factors as well. + +' $$E_{a, b\sim Pr(A,\ B)} \left[ \lambda_a \lambda_b \right] = E_{a\sim Pr(A)}\left[ \lambda_a E_{b \sim Pr(B | A=a)} [\lambda_b] \right] $$ + +' We can write one step of this chain rule really cleanly in Dex. + +def (~) (lambda:Var a) (pr: Dist a) (fn_a : a => Float) : Float = + expect pr $ for a'. lambda.a' * fn_a.a' + +' This allows us to finally write our model down in an extremely clean form. + +' + $$a \sim Pr(A)$$ + $$ b \sim Pr(B\ |\ A=a) $$ + +def coin_flip (a': Var Coin) (b': Var Dice) : Float = + (a' ~ p_A.nil) (for a. + (b' ~ p_B_A.a) one) + +' Now we can easily reproduce all the results above. + +coin_flip (observed heads) (observed (roll 6)) + +coin_flip (latent) (observed (roll 6)) + +coin_flip latent latent + +' This representation for joint distributions is known as the +*network polynomial*. This is a *multi-linear* function that uses +indicator variables to represent data observations. + +' $$ f(\lambda) = \sum_{\mathbf{x}} \prod_{x, \mathbf{u}\in \mathbf{x}} \lambda_x \theta_{x|\mathbf{u}} $$ + +' Here $\theta$ are the model parameters. These play the same role as above. + The $\lambda$ are *evidence indicators* which indicate the states of + the variable instantiations. + +' The *network polynomial* can be used to compute *marginal* probabilities of any + subset of variables. Let $\mathbf{e}$ be the observations of some subset of $\mathbf{X}$.
+ Darwiche shows that + +' $$f[\mathbf{e}] = p(\mathbf{E} = \mathbf{e})$$ + +' where $f[\mathbf{e}]$ assigns 1 to any $\lambda$ term that is consistent +(non-contradictory) with $\mathbf{e}$ and 0 otherwise. Let's look at an example. + +' ## Differential Inference + +' The network polynomial is a convenient method for computing probabilities, + but what makes it particularly useful is that it allows us to compute + posterior probabilities simply using derivatives. + +' For example, consider the probability of the coin flip given an observation of a + dice roll. We can compute this using Bayes' Rule. + +' $$Pr(A | B=b) \propto Pr(B=b | A) Pr(A)$$ + +normalize $ for a. coin_flip (observed a) (observed (roll 4)) + +' However, using the network polynomial we can compute this same term purely with + derivatives. For a variable $X$ that is not observed in $\mathbf{e}$, the partial + derivatives with respect to its indicators $\lambda_x$ directly yield joint terms. + +' $$\frac{\partial f[\mathbf{e}]}{\partial \lambda_x} = Pr(\mathbf{e}, x)$$ + +' This implies that the derivative of the log polynomial +yields posterior terms. + +' $$\frac{\partial \log f[\mathbf{e}]}{\partial \lambda_x} = Pr(x\ |\ \mathbf{e})$$ + +' Let us try this out. We can compute the posterior probability of + the coin after observing the dice roll. + +def posterior (f : (Var a) -> Float) : Dist a = + AsDist $ (grad (\ x. log $ f x)) one + + +posterior (\ x . coin_flip x (observed (roll 4)) ) + +' And this yields exactly the term above! This is really neat: it + doesn't require any model-specific inference code. + +' We can generalize this to compute a table of distributions. + +def posteriorTab (f : m => (Var a) -> Float) : m => Dist a = + out = (grad (\ x. log $ f x)) one + for i. AsDist $ out.i + + +' ## Example 2: Bayes Nets + + +' A classic example in probabilistic modeling is the wet grass Bayes net. + In this example we need to infer the factors that could have led to + the grass being wet. + +' More details on the problem are given [here](https://en.wikipedia.org/wiki/Bayesian_network).
+ +' ![grass](https://upload.wikimedia.org/wikipedia/commons/thumb/0/0e/SimpleBayesNet.svg/1024px-SimpleBayesNet.svg.png) + + +' + +Rain = {norain : Unit | rain : Unit} +Sprinkler = {nosprinkler : Unit | sprinkler : Unit} +Grass = {notwet : Unit | wet : Unit} +def bernoulli (p: Float) : Dist m = AsDist for i. [1.0 - p, p].((ordinal i)@_) + +' We now define the tables shown above. + +rain : Pr Rain (Fin 1) = [bernoulli 0.2] +sprinkler : Pr Sprinkler Rain = for r. bernoulli $ case r of + {|norain=()|} -> 0.4 + {|rain=()|} -> 0.01 + +grass : Pr Grass (Sprinkler & Rain) = for (s, r). bernoulli $ + case s of + {|nosprinkler=()|} -> + case r of + {|norain=()|} -> 0.0 + {|rain=()|} -> 0.8 + {|sprinkler=()|} -> + case r of + {|norain=()|} -> 0.9 + {|rain=()|} -> 0.99 + + +' And the architecture of the Bayes net. + +def wet_naive (r' : Var Rain) + (s' : Var Sprinkler) + (g' : Var Grass) : Float = + (r' ~ rain.nil) (for r. + (s' ~ sprinkler.r) (for s. + (g' ~ grass.(s,r)) one)) + +wet_naive (latent) (latent) (observed {|wet=()|}) + +posterior (\x. wet_naive x (latent) (observed {|wet=()|})) + +' ## Example 3: Dice Counting + +' Here's a classic elementary probability problem. Given two + standard dice rolls, what is the probability distribution + over their sum? + +' ![dice](https://qph.fs.quoracdn.net/main-qimg-13d2e066e80c0ac1511e0477c6ffdcb4-c) + +DiceSum = Range 2 13 + +' Helper functions for the dice sum. + +def (+@+) (a:a') (b:b') : c = (((ordinal a) + (ordinal b))@_) +def roll_sum (x:Int) : DiceSum = (x - 2)@_ + +def two_dice (dice : Var (Dice & Dice)) (dicesum : Var DiceSum) : Float = + (dice ~ uniform) (for (d1, d2). + (dicesum ~ delta (d1 +@+ d2)) one) + +' Here's the result. + +posterior (\m. two_dice latent m) + +' We might also ask for the probability of the dice rolls given an output value. + +support $ posterior (\m.
two_dice m (observed (roll_sum 4))) + +' ## Discussion - Conditional Independence + +' One tricky problem for discrete PPLs is modeling conditional independence. + Models can be very slow to compute if we are not careful to exploit + conditional independence properties such as Markov assumptions. + +' For example, let us consider a more complex version of the coin flip + example. We will flip three times. The choice of the second weighted coin + depends on the first. The choice of the third weighted coin depends on the second. + +' + $$a \sim Pr(A)$$ + $$ b \sim Pr(B\ |\ A=a) $$ + $$ c \sim Pr(C\ |\ B=b) $$ + +' In this example $C$ is conditionally independent of $A$ given $B$. + +' We can be lazy and create the distributions randomly. + +coin1 : Pr Coin None = arb $ newKey 1 +coin2 : Pr Coin Coin = arb $ newKey 2 +coin3 : Pr Coin Coin = arb $ newKey 3 + +' Now here is the generative process. + +def coin_flip2 (a': Var Coin) (b': Var Coin) (c': Var Coin) : Float = + (a' ~ coin1.nil) (for a. + (b' ~ coin2.a) (for b. + (c' ~ coin3.b) one)) + +' Note that as written this process looks like it does not take + advantage of the conditional independence property of the model. + The construction of the final coin is in a `for` constructor that + contains `a`. However, Dex knows that `a` is not used in the + innermost construct. In theory it can lift that out of the loop and exploit + the conditional independence. + +' Alternatively, we can make this explicit and do the lifting ourselves. + +def coin_flip_opt2 (a': Var Coin) (b': Var Coin) (c': Var Coin) : Float = + final_flip = for b. (c' ~ coin3.b) one + (a' ~ coin1.nil) (for a. + (b' ~ coin2.a) final_flip) + +' ## Example 4: Monty Hall Problem + +' Perhaps the most celebrated elementary problem in conditional + probability is the Monty Hall problem.
+ +' [Monty Hall](https://en.wikipedia.org/wiki/Monty_Hall_problem) + +' ![Goat](https://upload.wikimedia.org/wikipedia/commons/3/3f/Monty_open_door.svg) + +' You are on a game show. The host asks you to pick a door at random to win a prize. + After you select a door, one of the remaining doors (without the prize) is removed. + You are asked if you want to change your selection... + +Doors = Fin 3 +YesNo = { no:Unit | yes:Unit} +def yesno (x:Bool) : Dist YesNo = delta $ select x {|yes=()|} {|no=()|} + +' The generative model is relatively simple: + 1. We will first sample our pick and the winning door. + 1. Then we will consider changing our pick. + 1. Finally we will see if we won. + + +def monty_hall (change': Var YesNo) (win': Var YesNo) : Float = + (one ~ uniform) (for (pick, correct): (Doors & Doors). + (change' ~ uniform) (for change. + win_dist = case change of + {|yes=()|} -> yesno (pick /= correct) + {|no=()|} -> yesno (pick == correct) + (win' ~ win_dist) one)) + +' To check the odds we will compute the probability of winning conditioned + on changing. + +{|yes=()|} ?? (posterior $ monty_hall (observed {|yes=()|})) + + +' And compare it to the probability of winning with no change. + +{|yes=()|} ?? (posterior $ monty_hall (observed {|no=()|})) + + + +' ## Example 5: Hidden Markov Models + +' Finally we conclude with a more complex example. A hidden Markov model is + one of the most widely used discrete time series models. It models the relationship between discrete hidden states $Z$ and emissions $X$. + +Z = Fin 5 +X = Fin 10 + +' It consists of three distributions: initial, transition, and emission. + +initial : Pr Z None = arb $ newKey 1 +emission : Pr X Z = arb $ newKey 2 +transition : Pr Z Z = arb $ newKey 3 + +' The model itself takes the following form for $m$ steps. +' + $$ z_0 \sim \text{initial}$$ + $$ z_1 \sim \text{transition}(z_0)$$ + $$ x_1 \sim \text{emission}(z_1)$$ + $$ ...$$ + +' This is implemented in reverse order for clarity (backward algorithm).
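+
+' As a sketch of what the implementation below computes, the recursion can be read
+ off the code. Here $\lambda^{\text{init}}$, $\lambda^{x_j}$ and $\lambda^{z_j}$ stand for the
+ indicators passed as `init'`, `x'.j` and `z'.j`, with $j$ running over the $m$ steps
+ from $0$ to $m-1$; the symbol $\beta$ is notation introduced here and does not appear
+ in the code. The backward pass starts from $\beta_m(z) = 1$ and, for
+ $j = m-1, \ldots, 0$, sets
+
+' $$ \beta_j(z) = \left[ \sum_{x} \lambda^{x_j}_{x}\, \text{emission}(x\ |\ z) \right] \left[ \sum_{z'} \lambda^{z_j}_{z'}\, \text{transition}(z'\ |\ z)\, \beta_{j+1}(z') \right] $$
+
+' The value returned is then
+
+' $$ f(\lambda) = \sum_{z} \lambda^{\text{init}}_{z}\, \text{initial}(z)\, \beta_0(z) $$
+
+' Note that, as the code is written, the emission at step $j$ is drawn from the state
+ $z$ held before the transition at step $j$.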
+ +def hmm (init': Var Z) (x': m => Var X) (z' : m => Var Z) : Float = + (init' ~ initial.nil) $ yieldState one ( \future . + for i:m. + j = ((size m) - (ordinal i) - 1)@_ + future := for z. + (x'.j ~ emission.z) (for _. + (z'.j ~ transition.z) (get future))) + + +' We can marginalize out over latents. + +hmm (observed (1@_)) (for i:(Fin 2). observed (1@_)) (for i. latent) + + +' Or we can compute the posterior probabilities of specific values. + +posteriorTab $ \z . hmm (observed (1@_)) (for i:(Fin 2). observed (1@_)) z + +' ## Example 5a. HMM Monoid + +' We can also write out an HMM using a Monoid. Here we define a monoid + for square matrix multiplication. + +def MarkovMonoid (a:Type) : Monoid (a => a => Float) = + M = a -- XXX: Typing `Monoid a` below would quantify it over a, which we don't want + named-instance result : Monoid (M => M => Float) + mempty = for m1 m2. select ((ordinal m1) == (ordinal m2)) 1.0 0.0 + mcombine = \m1 m2. for i j. sum for k. m1.i.k * m2.k.j + result + +' We also define a Markov version of our sample function. + Instead of summing out over the usage of its result, + it returns a vector, which becomes one row of a matrix. + +def markov (lambda:Var a) (pr: Dist a) : a => Float = + for a'. (a' ?? pr) * lambda.a' + +' Here we write out the HMM using a forward-style approach. + Each time through the algorithm the accumulator represents + the matrix of the joint likelihood from position 1 to i. + +def hmm_monoid (init': Var Z) (x': m => Var X) (z' : m => Var Z) : Float = + scores = yieldAccum (MarkovMonoid Z) \ref . + for i:m. + ref += for z:Z. + emit = (x'.i ~ emission.z) one + emit .* (markov z'.i transition.z) + (init' ~ initial.nil) $ for j. sum scores.j + +' At first glance, this seems much less efficient. Above we + had an algorithm that only required $O(Z)$ storage whereas this + requires $O(Z^2)$. However, in theory this approach can be parallelized + over the step index set $m$. + +' This should give the same result as before.
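+
+' To see why the two implementations should agree, here is a sketch under the
+ assumption that `yieldAccum` combines the per-step matrices with `mcombine` in index
+ order. The names $M_i$, $\beta$ and $\lambda^{\text{init}}$ (the indicator `init'`) are
+ notation introduced here, not part of the code. Step $i$ contributes the matrix
+
+' $$ M_i[z, z'] = \left[ \sum_{x} \lambda^{x_i}_{x}\, \text{emission}(x\ |\ z) \right] \lambda^{z_i}_{z'}\, \text{transition}(z'\ |\ z) $$
+
+' and `hmm_monoid` returns
+
+' $$ \sum_{z} \lambda^{\text{init}}_{z}\, \text{initial}(z) \left[ M_0 M_1 \cdots M_{m-1} \mathbf{1} \right]_z $$
+
+' Writing $\beta_j$ for the `future` table of `hmm` after it has processed step $j$,
+ the backward pass computes the same product from right to left, as
+ $\beta_j = M_j \beta_{j+1}$ starting from $\beta_m = \mathbf{1}$. By associativity of
+ matrix multiplication the two values coincide.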
+ +hmm_monoid (observed (1@_)) (for i:(Fin 2). observed (1@_)) (for i. latent) + +' Unfortunately, the code for monoids does not yet allow for +auto-differentiation, so we compute the posterior with `hmm` instead. + +posteriorTab $ \z . hmm (observed (1@_)) (for i:(Fin 2). observed (1@_)) z + +' ## Fancier Distributions + +def without_replacement (y: n=>m) (AsDist x: Dist m) : Dist m = + renorm = sum for n'. x.(y.n') + AsDist $ for m'. + case (any for n'. (ordinal (y.n')) == (ordinal m')) of + False -> (x.m' / (1.0 - renorm)) + True -> 0.0 + + + +' ### Probability Exercises (from Stat 110 textbook) + +' A college has 10 (non-overlapping) time slots for its 10 courses, and blithely assigns +courses to time slots randomly and independently. A student randomly chooses 3 of the +courses to enroll in. What is the probability that there is a conflict in the student’s +schedule? + +Slot = Fin 10 +def courses (conflict:Var YesNo): Float = + (one ~ uniform) (for (i,j,k):(Slot& Slot& Slot). + (conflict ~ yesno ((i == j) || (j == k) || (i == k))) one) + +courses (observed {|yes=()|}) + +' A certain family has 6 children, consisting of 3 boys and 3 girls. Assuming that all +birth orders are equally likely, what is the probability that the 3 eldest children are the +3 girls? + +Children = Fin 6 +def birth (event:Var YesNo): Float = + (one ~ uniform) (for i:Children. + (one ~ without_replacement [i] uniform) (for j:Children. + (one ~ without_replacement [i, j] uniform) (for k:Children. + (event ~ yesno ((ordinal i < 3) && (ordinal j < 3) && (ordinal k < 3))) one))) + +birth (observed {|yes=()|}) + +
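+' As a sanity check on the two exercises above, the expected values can be worked out
+ by hand. This is arithmetic on the models as written, not output captured from Dex.
+ In `courses`, the three slots are independent and uniform over 10 values, so
+
+' $$ Pr(\text{conflict}) = 1 - \frac{10 \cdot 9 \cdot 8}{10^3} = 0.28 $$
+
+' In `birth`, the three draws are made without replacement from 6 children and each
+ must land in the first three ordinals, so
+
+' $$ Pr(\text{event}) = \frac{3}{6} \cdot \frac{2}{5} \cdot \frac{1}{4} = \frac{1}{20} = 0.05 $$
+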