|
121 | 121 | @test cov(f′(X′, Σy)) ≈ cov(f′2(X′, Σy))
|
122 | 122 | end
|
123 | 123 | end
|
| 124 | + |
| 125 | + @testset "sampling functions" begin |
| 126 | + rng, N, D = MersenneTwister(123456), 11, 5 |
| 127 | + X, f, Σy = generate_toy_problem(rng, N, D) |
| 128 | + |
| 129 | + g = rand(rng, f) |
| 130 | + @test g(X) == g(X) # check the sample doesn't change between evaluations |
| 131 | + |
| 132 | + Xc = ColVecs(X) |
| 133 | + Xr = RowVecs(X') |
| 134 | + @test g(X) == g(Xc) |
| 135 | + @test g(X) == g(Xr) |
| 136 | + |
| 137 | + # test the Random interface |
| 138 | + @test rand(rng, Random.Sampler(rng, f, Val(Inf))) isa BayesianLinearRegressors.BLRFunctionSample |
| 139 | + |
| 140 | + samples1, samples2 = 10_000, 1000 |
| 141 | + samples = samples1 * samples2 |
| 142 | + gs = rand(rng, f, samples1, samples2) |
| 143 | + @test size(gs) == (samples1, samples2) |
| 144 | + |
| 145 | + # test statistical properties of the sampled functions |
| 146 | + let |
| 147 | + Y = reduce(hcat, map(h -> h(X), reshape(gs, :))) |
| 148 | + m_empirical = mean(Y; dims = 2) |
| 149 | + Σ_empirical = (Y .- mean(Y; dims = 2)) * (Y .- mean(Y; dims = 2))' ./ samples |
| 150 | + @test mean(f(X, Σy)) ≈ m_empirical atol = 1e-3 rtol = 1e-3 |
| 151 | + @test cov(f(X, Σy)) ≈ Σ_empirical + Σy atol = 1e-3 rtol = 1e-3 |
| 152 | + end |
| 153 | + |
| 154 | + # test statistical properties of in-place rand |
| 155 | + let |
| 156 | + A = Array{BayesianLinearRegressors.BLRFunctionSample,2}( |
| 157 | + undef, |
| 158 | + samples1, |
| 159 | + samples2, |
| 160 | + ) |
| 161 | + A = rand!(rng, A, f) |
| 162 | + Y = reduce(hcat, map(h -> h(X), reshape(gs, :))) |
| 163 | + m_empirical = mean(Y; dims = 2) |
| 164 | + Σ_empirical = (Y .- mean(Y; dims = 2)) * (Y .- mean(Y; dims = 2))' ./ samples |
| 165 | + @test mean(f(X, Σy)) ≈ m_empirical atol = 1e-3 rtol = 1e-3 |
| 166 | + @test cov(f(X, Σy)) ≈ Σ_empirical + Σy atol = 1e-3 rtol = 1e-3 |
| 167 | + end |
| 168 | + |
| 169 | + @testset "Zygote (everything dense)" begin |
| 170 | + function test_rand_funcs_adjoints(sample_function) |
| 171 | + rng, N, D = MersenneTwister(123456), 11, 5 |
| 172 | + X, f, _ = generate_toy_problem(rng, N, D) |
| 173 | + mw, A_Λw = f.mw, 0.1 .* randn(rng, D, D) |
| 174 | + |
| 175 | + # Run the model forwards and check that output agrees with non-Zygote output. |
| 176 | + z, back = Zygote.pullback(sample_function, X, mw, A_Λw) |
| 177 | + @test z == sample_function(X, mw, A_Λw) |
| 178 | + |
| 179 | + # Compute adjoints using Zygote. |
| 180 | + z̄ = randn(rng, size(z)) |
| 181 | + dX, dmw, dA_Λw = back(z̄) |
| 182 | + |
| 183 | + # Verify adjoints via finite differencing. |
| 184 | + fdm = central_fdm(5, 1) |
| 185 | + @test dX ≈ first(j′vp(fdm, X -> sample_function(X, mw, A_Λw), z̄, X)) |
| 186 | + @test dmw ≈ first(j′vp(fdm, mw -> sample_function(X, mw, A_Λw), z̄, mw)) |
| 187 | + @test dA_Λw ≈ |
| 188 | + first(j′vp(fdm, A_Λw -> sample_function(X, mw, A_Λw), z̄, A_Λw)) |
| 189 | + end |
| 190 | + |
| 191 | + function rand_funcs_single(X, mw, A_Λw) |
| 192 | + Λw = Symmetric(A_Λw * A_Λw' + I) |
| 193 | + f = BayesianLinearRegressor(mw, Λw) |
| 194 | + g = rand(MersenneTwister(123456), f) |
| 195 | + return g(X) |
| 196 | + end |
| 197 | + |
| 198 | + function rand_funcs_multi(X, mw, A_Λw) |
| 199 | + Λw = Symmetric(A_Λw * A_Λw' + I) |
| 200 | + f = BayesianLinearRegressor(mw, Λw) |
| 201 | + gs = rand(MersenneTwister(123456), f, 1, 1) |
| 202 | + return reduce(hcat, map(h -> h(X), reshape(gs, :))) |
| 203 | + end |
| 204 | + |
| 205 | + test_rand_funcs_adjoints(rand_funcs_single) |
| 206 | + test_rand_funcs_adjoints(rand_funcs_multi) |
| 207 | + end |
| 208 | + end |
124 | 209 | end
|
0 commit comments