diff --git a/examples/kernelregression.dx b/examples/kernelregression.dx index 0135bd1f8..760149aa0 100644 --- a/examples/kernelregression.dx +++ b/examples/kernelregression.dx @@ -1,6 +1,7 @@ '# Kernel Regression import linalg +import stats import plot struct ConjGradState(a|VSpace) = @@ -43,7 +44,7 @@ The optimal coefficients are found by solving a linear system $\alpha=G^{-1}y$,\ where $G_{ij}:=k(x_i, x_j)+\delta_{ij}\lambda$, $\lambda>0$ and $y = (y_1,\dots,y_N)^\top\in\mathbb R^N$ -- Synthetic data -Nx = Fin 100 +Nx = Fin 20 noise = 0.1 [k1, k2] = split_key (new_key 0) @@ -69,9 +70,15 @@ Nxtest = Fin 1000 xtest : Nxtest=>Float = for i. rand (ixkey k1 i) preds = map predict xtest +-- True function. +:html show_plot $ xy_plot xtest (map trueFun xtest) +> + +-- Observed values. :html show_plot $ xy_plot xs ys > +-- Ridge regression prediction. :html show_plot $ xy_plot xtest preds > @@ -83,29 +90,44 @@ with the Bayes rule, gives the variance of the prediction. ' In this implementation, the conjugate gradient solver is replaced with the cholesky solver from `lib/linalg.dx` for efficiency. -def gp_regress( - kernel: (a, a) -> Float, - xs: n=>a, - ys: n=>Float - ) -> ((a) -> (Float, Float)) given (n|Ix, a) = - noise_var = 0.0001 +def gp( + kernel: (a, a) -> Float, + xs: n=>a, + mean_fn: (a) -> Float, + noise_var: Float +) -> MultivariateNormal n given (n|Ix, a) = gram = for i j. kernel xs[i] xs[j] - c = chol (gram + eye *. noise_var) - alpha = chol_solve c ys - predict = \x. - k' = for i. kernel xs[i] x - mu = sum for i. alpha[i] * k'[i] - alpha' = chol_solve c k' - var = kernel x x + noise_var - sum for i. k'[i] * alpha'[i] - (mu, var) - predict + loc = for i. mean_fn xs[i] + chol_cov = chol (gram + eye *. noise_var) + MultivariateNormal loc chol_cov -gp_predict = gp_regress (\x y. rbf 0.2 x y) xs ys - -(gp_preds, vars) = unzip (map gp_predict xtest) - -:html show_plot $ xyc_plot xtest gp_preds (map sqrt vars) +def gp_regress( + kernel: (a, a) -> Float, + xs: n=>a, + ys: n=>Float, + mean_fn: (a) -> Float, + noise_var: Float +) -> ((m=>a) -> MultivariateNormal m) given (n|Ix, m|Ix, a) = + prior_gp = gp kernel xs mean_fn noise_var + gram_obs_inv_y = chol_solve prior_gp.chol_cov (ys - prior_gp.loc) + predictive_gp_fn = \xs_pred:m=>a. + gram_pred_obs = for i j. kernel xs_pred[i] xs[j] + loc = gram_pred_obs **. gram_obs_inv_y + (for i. mean_fn xs_pred[i]) + gram_pred = (for i j. kernel xs_pred[i] xs_pred[j]) + eye *. noise_var + gram_obs_inv_gram_pred_obs = for i. chol_solve prior_gp.chol_cov gram_pred_obs[i] + schur = gram_pred - gram_obs_inv_gram_pred_obs ** (transpose gram_pred_obs) + MultivariateNormal loc (chol schur) + predictive_gp_fn + +def mean_fn(x:Float) -> Float = 0. * x +gp_predict_fn : (Nxtest=>Float) -> MultivariateNormal Nxtest = gp_regress (\x y. rbf 0.2 x y) xs ys mean_fn 0.0001 +gp_predict_dist = gp_predict_fn xtest +var_pred = for i. vdot gp_predict_dist.chol_cov[i] gp_predict_dist.chol_cov[i] + +-- GP posterior predictive mean, colored by variance. +:html show_plot $ xyc_plot xtest gp_predict_dist.loc (map sqrt var_pred) > -:html show_plot $ xy_plot xtest vars +-- Posterior predictive variance. +:html show_plot $ xy_plot xtest var_pred > diff --git a/lib/stats.dx b/lib/stats.dx index e4e688ccb..00ea01ea0 100644 --- a/lib/stats.dx +++ b/lib/stats.dx @@ -1,6 +1,8 @@ '# Stats Probability distributions and other functions useful for statistical computing. +import linalg + '## Log-space floating point numbers When working with probability densities, mass functions, distributions, likelihoods, etc., we often work on a logarithmic scale to prevent floating @@ -333,6 +335,26 @@ instance OrderedDist(Uniform, Float, Float) def quantile(d, q) = d.low + ((d.high - d.low) * q) +'## Multivariate probability distributions +Some commonly encountered multivariate distributions. +### Multivariate Normal distribution +The [Multivariate Normal distribution](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) is parameterised by its *mean*, `loc`, and Cholesky-factored `scale`. + +struct MultivariateNormal(n|Ix) = + loc : (n=>Float) + chol_cov : LowerTriMat n Float + +instance Random(MultivariateNormal(n), n=>Float) given (n|Ix) + def draw(d, k) = + std_norm = for i:n. randn (ixkey k i) + d.loc + for i:n. sum(for j:(..i). d.chol_cov[i, j] * std_norm[inject j]) + +instance Dist(MultivariateNormal(n), n=>Float, Float) given (n|Ix) + def density(d, x) = + y = forward_substitute d.chol_cov (x - d.loc) + dim = n_to_f (size n) + Exp (-(dim / 2) * log (2 * pi) - sum(log (lower_tri_diag d.chol_cov)) - 0.5 * dot y y) + '## Data summaries Some data summary functions. Note that `mean` is provided by the prelude. diff --git a/tests/stats-tests.dx b/tests/stats-tests.dx index 3f29c1999..cd35a2868 100644 --- a/tests/stats-tests.dx +++ b/tests/stats-tests.dx @@ -245,6 +245,34 @@ quantile (Uniform 2.0 5.0) 0.2 ~~ 2.6 rand_vec 5 (\k. draw (Uniform 2.0 5.0) k) (new_key 0) :: Fin 5=>Float > [4.610805, 2.740888, 2.510233, 3.040717, 3.731907] +-- multivariate normal + +draw (Uniform 2.0 5.0) (new_key 0) :: Float + +chol_cov_mat : Fin 2=>Fin 2=>Float = [[0.2, 0.], [-0.3, 0.1]] + +> Compiler bug! +> > Please report this at github.com/google-research/dex-lang/i +> > +> > Unexpected table: chol_co.1 +> > CallStack (from HasCallStack): +> > error, called at src/lib/Simplify.hs:571:22 in dex-0.1.0. +-- chol_cov = for i:(Fin 2). for j:(..i). chol_cov_mat[i, inject(to=Fin 2, j)] +-- chol_cov = for i:(Fin 2). for j:(..i). chol_cov_mat[i, (ordinal j)@_] +-- chol_cov = for i:(Fin 2). for j:(..i). chol_cov_mat[i, inject j] + +-- I think this used to work. +> > Type error: +> > Expected: (RangeTo (Fin 2) 0) +> > Actual: (Fin 1) +chol_cov : (i:Fin 2)=>(..i)=>Float = [[0.2], [-0.3, 0.1]] +loc : (Fin 2=>Float) = [1., 2.] +draw (MultivariateNormal loc chol_cov) (new_key 0) :: (Fin 2=>Float) +> [0.706645, 2.599938] + +ln (density (MultivariateNormal [1., 1] chol_cov) [0.5, 0.5]) ~~ -79.1758 +> True + -- data summaries