diff --git a/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl b/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl index 076d7a671..dc51241bd 100644 --- a/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl +++ b/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl @@ -166,10 +166,28 @@ Ancestral sampling works by: 2. Sampling from each node in order, using the already-sampled parent values for conditional distributions """ function ancestral_sampling(bn::BayesianNetwork{V}) where {V} - ordered_vertices = Graphs.topological_sort(bn.graph) - println(ordered_vertices) + ordered_vertices = Graphs.topological_sort_by_dfs(bn.graph) samples = Dict{V,Any}() + for vertex_id in ordered_vertices + vertex_name = bn.names[vertex_id] + + if bn.is_observed[vertex_id] + samples[vertex_name] = bn.values[vertex_name] + continue + end + + if bn.is_stochastic[vertex_id] + dist_idx = findfirst(id -> id == vertex_id, bn.stochastic_ids) + samples[vertex_name] = rand(bn.distributions[dist_idx]) + else + # deterministic node + parent_ids = Graphs.inneighbors(bn.graph, vertex_id) + parent_values = [samples[bn.names[pid]] for pid in parent_ids] + func_idx = findfirst(id -> id == vertex_id, bn.deterministic_ids) + samples[vertex_name] = bn.deterministic_functions[func_idx](parent_values...) + end + end return samples end diff --git a/test/experimental/ProbabilisticGraphicalModels/bayesnet.jl b/test/experimental/ProbabilisticGraphicalModels/bayesnet.jl index b98f8b036..12825dad0 100644 --- a/test/experimental/ProbabilisticGraphicalModels/bayesnet.jl +++ b/test/experimental/ProbabilisticGraphicalModels/bayesnet.jl @@ -97,33 +97,25 @@ using JuliaBUGS.ProbabilisticGraphicalModels: end @testset "Simple ancestral sampling" begin - # Create a Bayesian network bn = BayesianNetwork{Symbol}() - - # Add stochastic vertices with their distributions - add_stochastic_vertex!(bn, :A, Normal(0, 1), false) # Stochastic variable A - add_stochastic_vertex!(bn, :B, Normal(0, 1), false) # Stochastic variable B - add_stochastic_vertex!(bn, :C, Normal(0, 1), false) # Stochastic variable C - - # Add edges to define relationships - add_edge!(bn, :A, :B) # A -> B - add_edge!(bn, :A, :C) # A -> C - - # Perform ancestral sampling + + # Add stochastic vertices + add_stochastic_vertex!(bn, :A, Normal(0, 1), false) + add_stochastic_vertex!(bn, :B, Normal(1, 2), false) + + # Add deterministic vertex C = A + B + add_deterministic_vertex!(bn, :C, (a, b) -> a + b) + add_edge!(bn, :A, :C) + add_edge!(bn, :B, :C) + samples = ancestral_sampling(bn) - - # Debugging: Print samples to see its contents - println("Samples: ", samples) - - # Check if all sampled variables are present + @test haskey(samples, :A) @test haskey(samples, :B) @test haskey(samples, :C) - - # Check if the values are numerical since we are sampling from Normal distributions - @test isa(samples[:A], Number) - @test isa(samples[:B], Number) - @test isa(samples[:C], Number) + @test samples[:A] isa Number + @test samples[:B] isa Number + @test samples[:C] ≈ samples[:A] + samples[:B] end