Skip to content

Commit 9c20e07

Browse files
author
Will Tebbutt
authored
Wct/fix bfr (#38)
* Fix BFR posterior * Add comment for future readers * Bump patch version * Make tests more lightweight * Fix typo
1 parent 7d5607d commit 9c20e07

File tree

4 files changed

+50
-10
lines changed

4 files changed

+50
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BayesianLinearRegressors"
22
uuid = "f579363c-4606-5e5c-a623-c4549f609c4b"
33
authors = ["Will Tebbutt <[email protected]>"]
4-
version = "0.3.7"
4+
version = "0.3.8"
55

66
[deps]
77
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

src/basis_function_regression.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,30 @@ struct BasisFunctionRegressor{Tblr<:BayesianLinearRegressor,Tϕ} <: AbstractGP
3636
ϕ::Tϕ
3737
end
3838

39-
function (bfblr::BasisFunctionRegressor)(x::AbstractVector, args...)
40-
return bfblr.blr(bfblr.ϕ(x), args...)
39+
const FiniteBFR = FiniteGP{<:BasisFunctionRegressor}
40+
41+
_to_finite_blr(fx::FiniteBFR) = fx.f.blr(fx.f.ϕ(fx.x), fx.Σy)
42+
43+
# All functionality below just implements the primary and secondary AbstractGPs APIs.
44+
# See AbstractGPs.jl's documentation for information regarding their semantics.
45+
46+
AbstractGPs.mean(fx::FiniteBFR) = mean(_to_finite_blr(fx))
47+
48+
AbstractGPs.cov(fx::FiniteBFR) = cov(_to_finite_blr(fx))
49+
50+
AbstractGPs.var(fx::FiniteBFR) = var(_to_finite_blr(fx))
51+
52+
AbstractGPs.mean_and_cov(fx::FiniteBFR) = mean_and_cov(_to_finite_blr(fx))
53+
54+
AbstractGPs.mean_and_var(fx::FiniteBFR) = mean_and_var(_to_finite_blr(fx))
55+
56+
function AbstractGPs.rand(rng::AbstractRNG, fx::FiniteBFR, samples::Int)
57+
return rand(rng, _to_finite_blr(fx), samples)
58+
end
59+
60+
AbstractGPs.logpdf(fx::FiniteBFR, y::AbstractVector{<:Real}) = logpdf(_to_finite_blr(fx), y)
61+
62+
function AbstractGPs.posterior(fx::FiniteBFR, y::AbstractVector{<:Real})
63+
f_post = posterior(_to_finite_blr(fx), y)
64+
return BasisFunctionRegressor(f_post, fx.f.ϕ)
4165
end

test/basis_function_regression.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@testset "basis_function_regression" begin
22
@testset "basis_func_blr $Tx" for Tx in [Matrix, ColVecs, RowVecs]
3-
@testset "consistency" begin
3+
@testset "self-consistency" begin
44
rng, N, D, samples = MersenneTwister(123456), 11, 2, 1_000_000
55
X, f, Σy = generate_toy_problem(rng, N, D, Tx)
66

@@ -10,8 +10,24 @@
1010
rng, f_bf(X, Σy)
1111
)
1212
end
13+
@testset "consistency with BLR" begin
14+
rng = MersenneTwister(123456)
15+
N = 11
16+
D = 2
17+
X, f, Σy = generate_toy_problem(rng, N, D, Tx)
18+
f_bf = BasisFunctionRegressor(f, ϕ)
19+
20+
# Compute logpdf using both the BLR and BFR. Should agree.
21+
y = rand(rng, f_bf(X, Σy))
22+
@test logpdf(f(ϕ(X), Σy), y) logpdf(f_bf(X, Σy), y)
23+
24+
# Check that posteriors agree.
25+
f_bf_post = posterior(f_bf(X, Σy), y)
26+
f_post = posterior(f(ϕ(X), Σy), y)
27+
@test mean(f_bf_post(X)) mean(f_post(ϕ(X)))
28+
end
1329
@testset "rand" begin
14-
rng, N, D, samples = MersenneTwister(123456), 11, 2, 10_000_000
30+
rng, N, D, samples = MersenneTwister(123456), 11, 2, 1_000_000
1531
X, f, Σy = generate_toy_problem(rng, N, D, Tx)
1632

1733
f_bf = BasisFunctionRegressor(f, ϕ)
@@ -20,8 +36,8 @@
2036
Y = rand(rng, f_bf(X, Σy), samples)
2137
m_empirical = mean(Y; dims=2)
2238
Σ_empirical = (Y .- mean(Y; dims=2)) * (Y .- mean(Y; dims=2))' ./ samples
23-
@test mean(f_bf(X, Σy)) m_empirical atol = 1e-3 rtol = 1e-3
24-
@test cov(f_bf(X, Σy)) Σ_empirical atol = 1e-3 rtol = 1e-3
39+
@test mean(f_bf(X, Σy)) m_empirical atol = 1e-2 rtol = 1e-2
40+
@test cov(f_bf(X, Σy)) Σ_empirical atol = 1e-2 rtol = 1e-2
2541
end
2642
end
2743
end

test/bayesian_linear_regression.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99
)
1010
end
1111
@testset "rand" begin
12-
rng, N, D, samples = MersenneTwister(123456), 11, 3, 10_000_000
12+
rng, N, D, samples = MersenneTwister(123456), 11, 3, 1_000_000
1313
X, f, Σy = generate_toy_problem(rng, N, D, Tx)
1414

1515
# Roughly test the statistical properties of rand.
1616
Y = rand(rng, f(X, Σy), samples)
1717
m_empirical = mean(Y; dims=2)
1818
Σ_empirical = (Y .- mean(Y; dims=2)) * (Y .- mean(Y; dims=2))' ./ samples
19-
@test mean(f(X, Σy)) m_empirical atol = 1e-3 rtol = 1e-3
20-
@test cov(f(X, Σy)) Σ_empirical atol = 1e-3 rtol = 1e-3
19+
@test mean(f(X, Σy)) m_empirical atol = 1e-2 rtol = 1e-2
20+
@test cov(f(X, Σy)) Σ_empirical atol = 1e-2 rtol = 1e-2
2121

2222
@testset "Zygote (everything dense)" begin
2323
function rand_blr(X, A_Σy, mw, A_Λw)

0 commit comments

Comments
 (0)