Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ancestral Sampling and Bayes Ball Algorithm #233

Merged
merged 35 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
304fd88
change SimpleGraph to SimpleDiGraph for topological sort and added te…
naseweisssss Nov 4, 2024
573b3e5
implementation and tests for ancestral sampling
naseweisssss Nov 4, 2024
62c07ed
Update src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
naseweisssss Nov 4, 2024
344b4f1
Update src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
naseweisssss Nov 4, 2024
0bb810a
Update test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
naseweisssss Nov 4, 2024
6ce7559
Update test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
naseweisssss Nov 4, 2024
c385c37
Update test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
naseweisssss Nov 4, 2024
1d0ed43
Apply suggestions from code review
naseweisssss Nov 4, 2024
bc2fccf
add more complex tests
naseweisssss Nov 5, 2024
102069d
Apply suggestions from code review
naseweisssss Nov 5, 2024
503910f
add wrapper functions
naseweisssss Nov 5, 2024
593d18e
I dont think this works, test still fails
naseweisssss Nov 6, 2024
6f2bbcf
less test failings, but still not good
naseweisssss Nov 6, 2024
8ffba3b
linting
naseweisssss Nov 6, 2024
2c91dac
Merge branch 'master' into rylin/bayesnet_implementations
sunxd3 Nov 8, 2024
8339b0e
formatting
sunxd3 Nov 8, 2024
25308ff
add documents on return values
naseweisssss Nov 15, 2024
35c1c3f
TDD add test for VE algorithsm
naseweisssss Nov 15, 2024
1a83c12
added a sample implementation of variable elimination
naseweisssss Nov 15, 2024
c387813
formatting
naseweisssss Nov 15, 2024
deaf99c
temporary pass test
naseweisssss Nov 16, 2024
830c208
formatting
naseweisssss Nov 16, 2024
cae8914
added test case for the corner case X/Y in Z
naseweisssss Nov 17, 2024
966e4d1
Apply suggestions from code review
naseweisssss Nov 17, 2024
f9347bc
Improve `show` function of BUGSModel (#236)
sunxd3 Nov 14, 2024
a4cbbba
Permit dot call (like `Distributions.Normal`) to be used in model def…
sunxd3 Nov 16, 2024
1ca9933
remove comments
naseweisssss Nov 17, 2024
e229cac
Merge branch 'master' into rylin/bayesnet_implementations
sunxd3 Nov 18, 2024
34a17b2
remove conditioned descendant
naseweisssss Nov 18, 2024
0921f2e
remove VE
naseweisssss Nov 18, 2024
43027c5
extended bayes ball to have X and Y as a vector
naseweisssss Nov 19, 2024
a0ea8e9
removing debugging statement
naseweisssss Nov 19, 2024
a318f60
fix formating
sunxd3 Nov 20, 2024
121d412
formatting
sunxd3 Nov 20, 2024
0cff1db
Merge branch 'master' into rylin/bayesnet_implementations
sunxd3 Nov 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 105 additions & 7 deletions src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
A structure representing a Bayesian Network.
"""
struct BayesianNetwork{V,T,F}
graph::SimpleGraph{T}
graph::SimpleDiGraph{T}
"names of the variables in the network"
names::Vector{V}
"mapping from variable names to ids"
Expand All @@ -25,7 +25,7 @@ end

function BayesianNetwork{V}() where {V}
return BayesianNetwork(
SimpleGraph{Int}(), # by default, vertex ids are integers
SimpleDiGraph{Int}(), # by default, vertex ids are integers
V[],
Dict{V,Int}(),
Dict{V,Any}(),
Expand Down Expand Up @@ -166,11 +166,26 @@ Ancestral sampling works by:
2. Sampling from each node in order, using the already-sampled parent values for conditional distributions
"""
function ancestral_sampling(bn::BayesianNetwork{V}) where {V}
    # Visit vertices parents-first so every parent value exists before its children need it.
    ordered_vertices = Graphs.topological_sort_by_dfs(bn.graph)
    samples = Dict{V,Any}()

    for vertex_id in ordered_vertices
        vertex_name = bn.names[vertex_id]

        if bn.is_observed[vertex_id]
            # Observed vertices are never resampled; their stored value is reused as-is.
            samples[vertex_name] = bn.values[vertex_name]
        elseif bn.is_stochastic[vertex_id]
            # `bn.distributions` is indexed by position within `bn.stochastic_ids`,
            # not by vertex id.
            dist_idx = findfirst(==(vertex_id), bn.stochastic_ids)
            if dist_idx === nothing
                throw(
                    ArgumentError(
                        "stochastic vertex $(vertex_name) has no registered distribution"
                    ),
                )
            end
            samples[vertex_name] = rand(bn.distributions[dist_idx])
        else
            # Deterministic vertex: evaluate its function on the already-sampled parents.
            # NOTE(review): arguments are supplied in `Graphs.inneighbors` order, which is
            # presumably ascending vertex id rather than edge-insertion order — confirm
            # that this matches each function's parameter order.
            parent_ids = Graphs.inneighbors(bn.graph, vertex_id)
            parent_values = [samples[bn.names[pid]] for pid in parent_ids]
            func_idx = findfirst(==(vertex_id), bn.deterministic_ids)
            if func_idx === nothing
                throw(
                    ArgumentError(
                        "deterministic vertex $(vertex_name) has no registered function"
                    ),
                )
            end
            samples[vertex_name] = bn.deterministic_functions[func_idx](parent_values...)
        end
    end

    return samples
end
Expand All @@ -184,11 +199,94 @@ If Z is provided, the conditioning information in `bn` will be ignored.
function is_conditionally_independent end
naseweisssss marked this conversation as resolved.
Show resolved Hide resolved

function is_conditionally_independent(bn::BayesianNetwork{V}, X::V, Y::V) where {V}
    # No explicit conditioning set was given: condition on whatever `bn` currently
    # marks as observed, then defer to the four-argument method.
    observed_names = V[]
    for (name, observed) in zip(bn.names, bn.is_observed)
        observed && push!(observed_names, name)
    end
    return is_conditionally_independent(bn, X, Y, observed_names)
end

"""
    is_conditionally_independent(bn, X, Y, Z)

Return `true` when `X` and `Y` are d-separated in `bn` given the conditioning set `Z`,
and `false` when some active trail connects them.

The test is a Bayes-ball style breadth-first search over `(vertex, direction)` states,
where the direction records whether the ball arrived from a parent or from a child:

- arriving from a child at an unconditioned vertex, the ball passes on to both its
  parents and its children;
- arriving from a child at a conditioned vertex, the ball is blocked;
- arriving from a parent at an unconditioned vertex, the ball passes on to its children;
- arriving from a parent at a conditioned vertex (an activated collider), the ball
  bounces back to its parents.

A collider that is opened only through a conditioned descendant needs no special case:
the ball travels down to that descendant, bounces, and climbs back up.

If `X` or `Y` is itself in `Z` the result is `true`. Otherwise, if `X == Y`, the result
is `false`.

# Throws
- `KeyError` if `X`, `Y`, or any element of `Z` is not a vertex name of `bn`.
"""
function is_conditionally_independent(
    bn::BayesianNetwork{V}, X::V, Y::V, Z::Vector{V}
) where {V}
    # Resolve every name up front so unknown names raise KeyError before any shortcut.
    x_id = bn.names_to_ids[X]
    y_id = bn.names_to_ids[Y]
    z_ids = Set(bn.names_to_ids[z] for z in Z)

    # A variable that is conditioned on carries no remaining dependence.
    (x_id in z_ids || y_id in z_ids) && return true
    x_id == y_id && return false

    # A vertex may legitimately be entered once per direction, so the visited
    # bookkeeping is kept per (vertex, direction) rather than per vertex.
    n_vertices = Graphs.nv(bn.graph)
    seen_from_parent = falses(n_vertices)
    seen_from_child = falses(n_vertices)

    # Queue entries are (vertex_id, from_parent). Starting "from a child" lets the
    # ball leave X towards both its parents and its children.
    queue = Tuple{Int,Bool}[(x_id, false)]

    while !isempty(queue)
        current_id, from_parent = popfirst!(queue)

        seen = from_parent ? seen_from_parent : seen_from_child
        seen[current_id] && continue
        seen[current_id] = true

        # Y is known not to be conditioned here, so reaching it means an active trail.
        current_id == y_id && return false

        is_conditioned = current_id in z_ids

        if from_parent
            if is_conditioned
                # Conditioned vertex entered from a parent: v-structure is active.
                for parent in Graphs.inneighbors(bn.graph, current_id)
                    push!(queue, (parent, false))
                end
            else
                # Unconditioned chain vertex: keep moving down.
                for child in Graphs.outneighbors(bn.graph, current_id)
                    push!(queue, (child, true))
                end
            end
        elseif !is_conditioned
            # Entered from a child and unconditioned: continue up (chain) and
            # turn back down (fork).
            for parent in Graphs.inneighbors(bn.graph, current_id)
                push!(queue, (parent, false))
            end
            for child in Graphs.outneighbors(bn.graph, current_id)
                push!(queue, (child, true))
            end
        end
        # Entered from a child and conditioned: blocked, nothing to enqueue.
    end

    return true
end

# Return `true` if `node_id` itself, or any vertex reachable from it by following
# `outneighbors` edges in `bn.graph`, is a member of `z_ids`; otherwise return `false`.
# The search is a FIFO (breadth-first) traversal that visits each vertex at most once.
# NOTE(review): every call allocates a fresh `visited` vector and restarts the search,
# so calling this once per dequeued vertex repeats work across calls.
function has_conditioned_descendant(bn::BayesianNetwork, node_id::Int, z_ids::Set{Int})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

slight confusion here, why the ball can pass through if the collider child has conditioned descendants? and this could also incur repeated computations?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://pmc.ncbi.nlm.nih.gov/articles/PMC6089543/figure/F4/
image

I am considering this. Please let me know what do you think.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think it might incur repeated computations, but its a trade off for ensuring all possible paths are considered. I am thinking of keeping a visited nodes list, but unsure if that will cause conflict

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about this, this article is talking about causal effects. But here we are only testing conditional independence. These are related but not entirely identical concepts.

Also I think this is a good example why it is important to communicate why you implement the algorithms in the way you do.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah i see makes sense

# One flag per vertex of the graph; the search starts at `node_id` itself.
visited = falses(nv(bn.graph))
queue = Int[node_id]

while !isempty(queue)
# Take from the front of the queue so vertices are processed in FIFO order.
current = popfirst!(queue)

# A vertex can be enqueued more than once before it is processed; skip repeats.
if visited[current]
continue
end
visited[current] = true

# Stop as soon as a conditioned vertex is found (this includes `node_id` itself,
# because it is the first vertex dequeued).
if current in z_ids
return true
end

# Enqueue the children that have not been processed yet.
for child in outneighbors(bn.graph, current)
if !visited[child]
push!(queue, child)
end
end
end

# The queue drained without meeting a conditioned vertex.
return false
end
167 changes: 163 additions & 4 deletions test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ using JuliaBUGS.ProbabilisticGraphicalModels:
add_deterministic_vertex!,
add_edge!,
condition,
decondition

decondition,
ancestral_sampling,
is_conditionally_independent
@testset "BayesianNetwork" begin
@testset "Adding vertices" begin
bn = BayesianNetwork{Symbol}()
Expand Down Expand Up @@ -96,7 +97,165 @@ using JuliaBUGS.ProbabilisticGraphicalModels:
@test bn_cond2.values[:B] == 2.0
end

@testset "Simple ancestral sampling" begin end
@testset "Simple ancestral sampling" begin
    # Two stochastic roots A and B feeding one deterministic vertex C = A + B.
    bn = BayesianNetwork{Symbol}()
    add_stochastic_vertex!(bn, :A, Normal(0, 1), false)
    add_stochastic_vertex!(bn, :B, Normal(1, 2), false)
    add_deterministic_vertex!(bn, :C, (a, b) -> a + b)
    for parent in (:A, :B)
        add_edge!(bn, parent, :C)
    end

    draw = ancestral_sampling(bn)

    # Every vertex must appear in the returned sample.
    for name in (:A, :B, :C)
        @test haskey(draw, name)
    end

    # Stochastic draws are numeric.
    for name in (:A, :B)
        @test draw[name] isa Number
    end

    # The deterministic vertex is computed from the values drawn in the same pass.
    @test draw[:C] ≈ draw[:A] + draw[:B]
end

@testset "Complex ancestral sampling" begin
    bn = BayesianNetwork{Symbol}()

    # Stochastic roots; registration order is kept so vertex ids stay the same.
    for (name, dist) in (
        (:μ, Normal(0, 2)), (:σ, LogNormal(0, 0.5)), (:X, Normal(0, 1)), (:Y, Normal(0, 1))
    )
        add_stochastic_vertex!(bn, name, dist, false)
    end

    # Deterministic vertices derived from the roots.
    add_deterministic_vertex!(bn, :X_scaled, (μ, σ, x) -> x * σ + μ)
    add_deterministic_vertex!(bn, :Y_scaled, (μ, σ, y) -> y * σ + μ)
    add_deterministic_vertex!(bn, :Sum, (x, y) -> x + y)
    add_deterministic_vertex!(bn, :Product, (x, y) -> x * y)
    add_deterministic_vertex!(bn, :N, () -> 2.0)
    add_deterministic_vertex!(bn, :Mean, (s, n) -> s / n)

    # Wiring, listed as (parent, child) in the original insertion order.
    for (parent, child) in [
        (:μ, :X_scaled),
        (:σ, :X_scaled),
        (:X, :X_scaled),
        (:μ, :Y_scaled),
        (:σ, :Y_scaled),
        (:Y, :Y_scaled),
        (:X_scaled, :Sum),
        (:Y_scaled, :Sum),
        (:X_scaled, :Product),
        (:Y_scaled, :Product),
        (:Sum, :Mean),
        (:N, :Mean),
    ]
        add_edge!(bn, parent, child)
    end

    draw = ancestral_sampling(bn)

    expected_names = [:μ, :σ, :X, :Y, :X_scaled, :Y_scaled, :Sum, :Product, :Mean, :N]
    @test all(name -> haskey(draw, name), expected_names)
    @test all(value -> value isa Number, values(draw))

    # Each deterministic vertex must agree with its definition on the drawn parent values.
    @test draw[:X_scaled] ≈ draw[:X] * draw[:σ] + draw[:μ]
    @test draw[:Y_scaled] ≈ draw[:Y] * draw[:σ] + draw[:μ]
    @test draw[:Sum] ≈ draw[:X_scaled] + draw[:Y_scaled]
    @test draw[:Product] ≈ draw[:X_scaled] * draw[:Y_scaled]
    @test draw[:Mean] ≈ draw[:Sum] / draw[:N]
    @test draw[:N] ≈ 2.0
    @test draw[:σ] > 0

    # Repeated sampling: the empirical mean of :Mean should sit near 0 and vary.
    n_samples = 1000
    means = Float64[ancestral_sampling(bn)[:Mean] for _ in 1:n_samples]

    @test mean(means) ≈ 0 atol = 0.5
    @test std(means) > 0
end

@testset "Bayes Ball" begin
    @testset "Chain Structure (A → B → C)" begin
        bn = BayesianNetwork{Symbol}()

        add_stochastic_vertex!(bn, :A, Normal(), false)
        add_stochastic_vertex!(bn, :B, Normal(), false)
        add_stochastic_vertex!(bn, :C, Normal(), false)

        add_edge!(bn, :A, :B)
        add_edge!(bn, :B, :C)

        # Conditioning on the middle of a chain blocks it.
        @test is_conditionally_independent(bn, :A, :C, [:B])
        @test !is_conditionally_independent(bn, :A, :C, Symbol[])
    end

    @testset "Fork Structure (A ← B → C)" begin
        bn = BayesianNetwork{Symbol}()

        add_stochastic_vertex!(bn, :A, Normal(), false)
        add_stochastic_vertex!(bn, :B, Normal(), false)
        add_stochastic_vertex!(bn, :C, Normal(), false)

        add_edge!(bn, :B, :A)
        add_edge!(bn, :B, :C)

        # Conditioning on the common cause blocks the fork.
        @test is_conditionally_independent(bn, :A, :C, [:B])
        @test !is_conditionally_independent(bn, :A, :C, Symbol[])
    end

    @testset "Collider Structure (A → B ← C)" begin
        bn = BayesianNetwork{Symbol}()

        add_stochastic_vertex!(bn, :A, Normal(), false)
        add_stochastic_vertex!(bn, :B, Normal(), false)
        add_stochastic_vertex!(bn, :C, Normal(), false)

        add_edge!(bn, :A, :B)
        add_edge!(bn, :C, :B)

        # A collider is blocked by default and opened by conditioning on it.
        @test is_conditionally_independent(bn, :A, :C, Symbol[])
        @test !is_conditionally_independent(bn, :A, :C, [:B])
    end

    @testset "Complex Structure" begin
        bn = BayesianNetwork{Symbol}()

        for v in [:A, :B, :C, :D, :E]
            add_stochastic_vertex!(bn, v, Normal(), false)
        end

        # Structure (edges A→B, B→C, B→D, C→E, E→D):
        #   A → B → D
        #       ↓   ↑
        #       C → E
        add_edge!(bn, :A, :B)
        add_edge!(bn, :B, :C)
        add_edge!(bn, :B, :D)
        add_edge!(bn, :C, :E)
        add_edge!(bn, :E, :D)

        @test is_conditionally_independent(bn, :A, :E, [:B, :C])
        # A's only edge is A → B and B has no other parent, so B is a non-collider on
        # every trail leaving A. Conditioning on B alone therefore d-separates A from E.
        @test is_conditionally_independent(bn, :A, :E, [:B])
        @test !is_conditionally_independent(bn, :A, :E, Symbol[])
    end

    @testset "Using Observed Variables" begin
        bn = BayesianNetwork{Symbol}()

        add_stochastic_vertex!(bn, :A, Normal(), false)
        add_stochastic_vertex!(bn, :B, Normal(), true)  # B is observed
        add_stochastic_vertex!(bn, :C, Normal(), false)

        add_edge!(bn, :A, :B)
        add_edge!(bn, :B, :C)

        # Without an explicit Z, the observed vertices act as the conditioning set.
        @test is_conditionally_independent(bn, :A, :C)

        bn_decond = decondition(bn)
        @test !is_conditionally_independent(bn_decond, :A, :C)
    end

    @testset "Error Handling" begin
        bn = BayesianNetwork{Symbol}()

        add_stochastic_vertex!(bn, :A, Normal(), false)
        add_stochastic_vertex!(bn, :B, Normal(), false)

        # Unknown names in any position must surface as KeyError.
        @test_throws KeyError is_conditionally_independent(bn, :A, :NonExistent)
        @test_throws KeyError is_conditionally_independent(bn, :NonExistent, :B)
        @test_throws KeyError is_conditionally_independent(bn, :A, :B, [:NonExistent])
    end
end
end
Loading