Skip to content

Commit 0d51fca

Browse files
Drawing function samples (#24)
* Implement drawing function samples * ColVecs/RowVecs test * Apply suggestions from code review Co-authored-by: willtebbutt <[email protected]> * Efficient multisample generation * Fix test * Non-mutating rand and autodiff tests * Remove `Random.eltype` definition * Add back in-place and clean up tests * test Random.Sampler * Patch bump Co-authored-by: willtebbutt <[email protected]>
1 parent b670ec4 commit 0d51fca

File tree

3 files changed

+117
-1
lines changed

3 files changed

+117
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BayesianLinearRegressors"
22
uuid = "f579363c-4606-5e5c-a623-c4549f609c4b"
33
authors = ["Will Tebbutt <[email protected]>"]
4-
version = "0.3.3"
4+
version = "0.3.4"
55

66
[deps]
77
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

src/bayesian_linear_regression.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,34 @@ function __compute_inference_quantities(fx::FiniteBLR, y::AbstractVector{<:Real}
7373

7474
return Uw, Bt, δy, logpdf_δy, Λεy
7575
end
76+
77+
# Random function sample generation
78+
# Following the Random API: https://docs.julialang.org/en/v1/stdlib/Random/#Hooking-into-the-Random-API
79+
struct BLRFunctionSample{Tw<:AbstractVector}
80+
w::Tw
81+
end
82+
83+
(s::BLRFunctionSample)(X::AbstractMatrix{<:Real}) = X's.w
84+
(s::BLRFunctionSample)(X::ColVecs) = X.X's.w
85+
(s::BLRFunctionSample)(X::RowVecs) = X.X * s.w
86+
87+
Random.Sampler(::Type{<:AbstractRNG}, blr::BayesianLinearRegressor, ::Random.Repetition) = blr
88+
89+
function Random.rand(rng::AbstractRNG, blr::BayesianLinearRegressor)
90+
w = blr.mw .+ _cholesky(blr.Λw).U \ randn(rng, size(blr.mw))
91+
return BLRFunctionSample(w)
92+
end
93+
94+
function Random.rand(rng::AbstractRNG, blr::BayesianLinearRegressor, dims::Dims)
95+
ws = blr.mw .+ _cholesky(blr.Λw).U \ randn(rng, (only(size(blr.mw)), prod(dims)))
96+
bs = [BLRFunctionSample(w) for w in eachcol(ws)]
97+
return reshape(bs, dims)
98+
end
99+
100+
function Random.rand!(rng::AbstractRNG, A::AbstractArray{<:BLRFunctionSample}, blr::BayesianLinearRegressor)
101+
ws = blr.mw .+ _cholesky(blr.Λw).U \ randn(rng, (only(size(blr.mw)), prod(size(A))))
102+
for i in LinearIndices(A)
103+
@inbounds A[i] = BLRFunctionSample(ws[:,i])
104+
end
105+
return A
106+
end

test/bayesian_linear_regression.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,89 @@ end
121121
@test cov(f′(X′, Σy)) cov(f′2(X′, Σy))
122122
end
123123
end
124+
125+
@testset "sampling functions" begin
126+
rng, N, D = MersenneTwister(123456), 11, 5
127+
X, f, Σy = generate_toy_problem(rng, N, D)
128+
129+
g = rand(rng, f)
130+
@test g(X) == g(X) # check the sample doesn't change between evaluations
131+
132+
Xc = ColVecs(X)
133+
Xr = RowVecs(X')
134+
@test g(X) == g(Xc)
135+
@test g(X) == g(Xr)
136+
137+
# test the Random interface
138+
@test rand(rng, Random.Sampler(rng, f, Val(Inf))) isa BayesianLinearRegressors.BLRFunctionSample
139+
140+
samples1, samples2 = 10_000, 1000
141+
samples = samples1 * samples2
142+
gs = rand(rng, f, samples1, samples2)
143+
@test size(gs) == (samples1, samples2)
144+
145+
# test statistical properties of the sampled functions
146+
let
147+
Y = reduce(hcat, map(h -> h(X), reshape(gs, :)))
148+
m_empirical = mean(Y; dims = 2)
149+
Σ_empirical = (Y .- mean(Y; dims = 2)) * (Y .- mean(Y; dims = 2))' ./ samples
150+
@test mean(f(X, Σy)) m_empirical atol = 1e-3 rtol = 1e-3
151+
@test cov(f(X, Σy)) Σ_empirical + Σy atol = 1e-3 rtol = 1e-3
152+
end
153+
154+
# test statistical properties of in-place rand
155+
let
156+
A = Array{BayesianLinearRegressors.BLRFunctionSample,2}(
157+
undef,
158+
samples1,
159+
samples2,
160+
)
161+
A = rand!(rng, A, f)
162+
Y = reduce(hcat, map(h -> h(X), reshape(gs, :)))
163+
m_empirical = mean(Y; dims = 2)
164+
Σ_empirical = (Y .- mean(Y; dims = 2)) * (Y .- mean(Y; dims = 2))' ./ samples
165+
@test mean(f(X, Σy)) m_empirical atol = 1e-3 rtol = 1e-3
166+
@test cov(f(X, Σy)) Σ_empirical + Σy atol = 1e-3 rtol = 1e-3
167+
end
168+
169+
@testset "Zygote (everything dense)" begin
170+
function test_rand_funcs_adjoints(sample_function)
171+
rng, N, D = MersenneTwister(123456), 11, 5
172+
X, f, _ = generate_toy_problem(rng, N, D)
173+
mw, A_Λw = f.mw, 0.1 .* randn(rng, D, D)
174+
175+
# Run the model forwards and check that output agrees with non-Zygote output.
176+
z, back = Zygote.pullback(sample_function, X, mw, A_Λw)
177+
@test z == sample_function(X, mw, A_Λw)
178+
179+
# Compute adjoints using Zygote.
180+
= randn(rng, size(z))
181+
dX, dmw, dA_Λw = back(z̄)
182+
183+
# Verify adjoints via finite differencing.
184+
fdm = central_fdm(5, 1)
185+
@test dX first(j′vp(fdm, X -> sample_function(X, mw, A_Λw), z̄, X))
186+
@test dmw first(j′vp(fdm, mw -> sample_function(X, mw, A_Λw), z̄, mw))
187+
@test dA_Λw
188+
first(j′vp(fdm, A_Λw -> sample_function(X, mw, A_Λw), z̄, A_Λw))
189+
end
190+
191+
function rand_funcs_single(X, mw, A_Λw)
192+
Λw = Symmetric(A_Λw * A_Λw' + I)
193+
f = BayesianLinearRegressor(mw, Λw)
194+
g = rand(MersenneTwister(123456), f)
195+
return g(X)
196+
end
197+
198+
function rand_funcs_multi(X, mw, A_Λw)
199+
Λw = Symmetric(A_Λw * A_Λw' + I)
200+
f = BayesianLinearRegressor(mw, Λw)
201+
gs = rand(MersenneTwister(123456), f, 1, 1)
202+
return reduce(hcat, map(h -> h(X), reshape(gs, :)))
203+
end
204+
205+
test_rand_funcs_adjoints(rand_funcs_single)
206+
test_rand_funcs_adjoints(rand_funcs_multi)
207+
end
208+
end
124209
end

0 commit comments

Comments
 (0)