Skip to content

Commit

Permalink
implementation and tests for ancestral sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
naseweisssss committed Nov 4, 2024
1 parent 304fd88 commit 573b3e5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 24 deletions.
22 changes: 20 additions & 2 deletions src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,28 @@ Ancestral sampling works by:
2. Sampling from each node in order, using the already-sampled parent values for conditional distributions
"""
function ancestral_sampling(bn::BayesianNetwork{V}) where {V}
ordered_vertices = Graphs.topological_sort(bn.graph)
println(ordered_vertices)
ordered_vertices = Graphs.topological_sort_by_dfs(bn.graph)
samples = Dict{V,Any}()

for vertex_id in ordered_vertices
vertex_name = bn.names[vertex_id]

if bn.is_observed[vertex_id]
samples[vertex_name] = bn.values[vertex_name]
continue
end

if bn.is_stochastic[vertex_id]
dist_idx = findfirst(id -> id == vertex_id, bn.stochastic_ids)
samples[vertex_name] = rand(bn.distributions[dist_idx])
else
# deterministic node
parent_ids = Graphs.inneighbors(bn.graph, vertex_id)
parent_values = [samples[bn.names[pid]] for pid in parent_ids]
func_idx = findfirst(id -> id == vertex_id, bn.deterministic_ids)
samples[vertex_name] = bn.deterministic_functions[func_idx](parent_values...)
end
end

return samples
end
Expand Down
36 changes: 14 additions & 22 deletions test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,33 +97,25 @@ using JuliaBUGS.ProbabilisticGraphicalModels:
end

@testset "Simple ancestral sampling" begin
# Create a Bayesian network
bn = BayesianNetwork{Symbol}()

# Add stochastic vertices with their distributions
add_stochastic_vertex!(bn, :A, Normal(0, 1), false) # Stochastic variable A
add_stochastic_vertex!(bn, :B, Normal(0, 1), false) # Stochastic variable B
add_stochastic_vertex!(bn, :C, Normal(0, 1), false) # Stochastic variable C

# Add edges to define relationships
add_edge!(bn, :A, :B) # A -> B
add_edge!(bn, :A, :C) # A -> C

# Perform ancestral sampling

# Add stochastic vertices
add_stochastic_vertex!(bn, :A, Normal(0, 1), false)
add_stochastic_vertex!(bn, :B, Normal(1, 2), false)

# Add deterministic vertex C = A + B
add_deterministic_vertex!(bn, :C, (a, b) -> a + b)
add_edge!(bn, :A, :C)
add_edge!(bn, :B, :C)

samples = ancestral_sampling(bn)

# Debugging: Print samples to see its contents
println("Samples: ", samples)

# Check if all sampled variables are present

@test haskey(samples, :A)
@test haskey(samples, :B)
@test haskey(samples, :C)

# Check if the values are numerical since we are sampling from Normal distributions
@test isa(samples[:A], Number)
@test isa(samples[:B], Number)
@test isa(samples[:C], Number)
@test samples[:A] isa Number
@test samples[:B] isa Number
@test samples[:C] samples[:A] + samples[:B]
end


Expand Down

0 comments on commit 573b3e5

Please sign in to comment.