Skip to content

Commit

Permalink
add more complex tests
Browse files Browse the repository at this point in the history
  • Loading branch information
naseweisssss committed Nov 5, 2024
1 parent 1d0ed43 commit bc2fccf
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,64 @@ using JuliaBUGS.ProbabilisticGraphicalModels:
@test samples[:C] samples[:A] + samples[:B]
end

@testset "Complex ancestral sampling" begin
bn = BayesianNetwork{Symbol}()

add_stochastic_vertex!(bn, , Normal(0, 2), false)
add_stochastic_vertex!(bn, , LogNormal(0, 0.5), false)
add_stochastic_vertex!(bn, :X, Normal(0, 1), false)
add_stochastic_vertex!(bn, :Y, Normal(0, 1), false)

add_deterministic_vertex!(bn, :X_scaled, (μ, σ, x) -> x * σ + μ)
add_deterministic_vertex!(bn, :Y_scaled, (μ, σ, y) -> y * σ + μ)
add_deterministic_vertex!(bn, :Sum, (x, y) -> x + y)
add_deterministic_vertex!(bn, :Product, (x, y) -> x * y)
add_deterministic_vertex!(bn, :N, () -> 2.0)
add_deterministic_vertex!(bn, :Mean, (s, n) -> s / n)

add_edge!(bn, , :X_scaled)
add_edge!(bn, , :X_scaled)
add_edge!(bn, :X, :X_scaled)

add_edge!(bn, , :Y_scaled)
add_edge!(bn, , :Y_scaled)
add_edge!(bn, :Y, :Y_scaled)

add_edge!(bn, :X_scaled, :Sum)
add_edge!(bn, :Y_scaled, :Sum)

add_edge!(bn, :X_scaled, :Product)
add_edge!(bn, :Y_scaled, :Product)

add_edge!(bn, :Sum, :Mean)
add_edge!(bn, :N, :Mean)

samples = ancestral_sampling(bn)

@test all(haskey(samples, k) for k in [, , :X, :Y, :X_scaled, :Y_scaled, :Sum, :Product, :Mean, :N])

@test all(samples[k] isa Number for k in keys(samples))

@test samples[:X_scaled] samples[:X] * samples[] + samples[]
@test samples[:Y_scaled] samples[:Y] * samples[] + samples[]
@test samples[:Sum] samples[:X_scaled] + samples[:Y_scaled]
@test samples[:Product] samples[:X_scaled] * samples[:Y_scaled]
@test samples[:Mean] samples[:Sum] / samples[:N]
@test samples[:N] 2.0

@test samples[] > 0

# Multiple samples test
n_samples = 1000
means = zeros(n_samples)
for i in 1:n_samples
samples = ancestral_sampling(bn)
means[i] = samples[:Mean]
end

@test mean(means) 0 atol=0.5
@test std(means) > 0
end

@testset "Bayes Ball" begin end
end

0 comments on commit bc2fccf

Please sign in to comment.