Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ancestral Sampling and Bayes Ball Algorithm #233

Merged
merged 35 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
304fd88
change SimpleGraph to SimpleDiGraph for topological sort and added te…
naseweisssss Nov 4, 2024
573b3e5
implementation and tests for ancestral sampling
naseweisssss Nov 4, 2024
62c07ed
Update src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
naseweisssss Nov 4, 2024
344b4f1
Update src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
naseweisssss Nov 4, 2024
0bb810a
Update test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
naseweisssss Nov 4, 2024
6ce7559
Update test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
naseweisssss Nov 4, 2024
c385c37
Update test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
naseweisssss Nov 4, 2024
1d0ed43
Apply suggestions from code review
naseweisssss Nov 4, 2024
bc2fccf
add more complex tests
naseweisssss Nov 5, 2024
102069d
Apply suggestions from code review
naseweisssss Nov 5, 2024
503910f
add wrapper functions
naseweisssss Nov 5, 2024
593d18e
I dont think this works, test still fails
naseweisssss Nov 6, 2024
6f2bbcf
less test failings, but still not good
naseweisssss Nov 6, 2024
8ffba3b
linting
naseweisssss Nov 6, 2024
2c91dac
Merge branch 'master' into rylin/bayesnet_implementations
sunxd3 Nov 8, 2024
8339b0e
formatting
sunxd3 Nov 8, 2024
25308ff
add documents on return values
naseweisssss Nov 15, 2024
35c1c3f
TDD add test for VE algorithsm
naseweisssss Nov 15, 2024
1a83c12
added a sample implementation of variable elimination
naseweisssss Nov 15, 2024
c387813
formatting
naseweisssss Nov 15, 2024
deaf99c
temporary pass test
naseweisssss Nov 16, 2024
830c208
formatting
naseweisssss Nov 16, 2024
cae8914
added test case for the corner case X/Y in Z
naseweisssss Nov 17, 2024
966e4d1
Apply suggestions from code review
naseweisssss Nov 17, 2024
f9347bc
Improve `show` function of BUGSModel (#236)
sunxd3 Nov 14, 2024
a4cbbba
Permit dot call (like `Distributions.Normal`) to be used in model def…
sunxd3 Nov 16, 2024
1ca9933
remove comments
naseweisssss Nov 17, 2024
e229cac
Merge branch 'master' into rylin/bayesnet_implementations
sunxd3 Nov 18, 2024
34a17b2
remove conditioned descendant
naseweisssss Nov 18, 2024
0921f2e
remove VE
naseweisssss Nov 18, 2024
43027c5
extended bayes ball to have X and Y as a vector
naseweisssss Nov 19, 2024
a0ea8e9
removing debugging statement
naseweisssss Nov 19, 2024
a318f60
fix formating
sunxd3 Nov 20, 2024
121d412
formatting
sunxd3 Nov 20, 2024
0cff1db
Merge branch 'master' into rylin/bayesnet_implementations
sunxd3 Nov 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,15 @@ using Distributions

include("bayesnet.jl")

export BayesianNetwork,
add_stochastic_vertex!,
add_deterministic_vertex!,
add_edge!,
condition,
condition!,
decondition,
decondition!,
ancestral_sampling,
is_conditionally_independent

end
161 changes: 149 additions & 12 deletions src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
A structure representing a Bayesian Network.
"""
struct BayesianNetwork{V,T,F}
graph::SimpleGraph{T}
graph::SimpleDiGraph{T}
"names of the variables in the network"
names::Vector{V}
"mapping from variable names to ids"
Expand All @@ -25,7 +25,7 @@ end

function BayesianNetwork{V}() where {V}
return BayesianNetwork(
SimpleGraph{Int}(), # by default, vertex ids are integers
SimpleDiGraph{Int}(), # by default, vertex ids are integers
V[],
Dict{V,Int}(),
Dict{V,Any}(),
Expand Down Expand Up @@ -164,31 +164,168 @@ Perform ancestral sampling on a Bayesian network to generate one sample from the
Ancestral sampling works by:
1. Finding a topological ordering of the nodes
2. Sampling from each node in order, using the already-sampled parent values for conditional distributions

### Return Value
The function returns a `Dict{V, Any}` where:
- Each key is a variable name (of type `V`) in the Bayesian Network.
- Each value is the sampled value for that variable, which can be of any type (`Any`).

This dictionary represents a single sample from the joint distribution of the Bayesian Network, capturing the dependencies and conditional relationships defined in the network structure.

"""
function ancestral_sampling(bn::BayesianNetwork{V}) where {V}
    # Visit vertices in topological order: parents are the in-neighbours of a vertex,
    # so every parent value a deterministic node needs is already stored in `samples`
    # by the time that node is reached.
    ordered_vertices = Graphs.topological_sort_by_dfs(bn.graph)
    samples = Dict{V,Any}()

    for vertex_id in ordered_vertices
        vertex_name = bn.names[vertex_id]

        if bn.is_observed[vertex_id]
            # Observed vertices keep their stored value instead of being resampled.
            samples[vertex_name] = bn.values[vertex_name]
        elseif bn.is_stochastic[vertex_id]
            # `distributions` is indexed by position within `stochastic_ids`,
            # not by vertex id.
            dist_idx = findfirst(==(vertex_id), bn.stochastic_ids)
            if dist_idx === nothing
                throw(
                    ArgumentError(
                        "stochastic vertex $vertex_name has no entry in stochastic_ids"
                    ),
                )
            end
            # NOTE(review): the stored distribution is sampled as-is; the sampled
            # parent values are not passed to it here. Confirm this is intended.
            samples[vertex_name] = rand(bn.distributions[dist_idx])
        else
            # Deterministic vertex: apply its function to the parents' values.
            # Arguments are passed in the order `Graphs.inneighbors` returns them.
            func_idx = findfirst(==(vertex_id), bn.deterministic_ids)
            if func_idx === nothing
                throw(
                    ArgumentError(
                        "deterministic vertex $vertex_name has no entry in deterministic_ids",
                    ),
                )
            end
            parent_ids = Graphs.inneighbors(bn.graph, vertex_id)
            parent_values = [samples[bn.names[pid]] for pid in parent_ids]
            samples[vertex_name] = bn.deterministic_functions[func_idx](parent_values...)
        end
    end

    return samples
end

"""
    is_conditionally_independent(bn::BayesianNetwork, X::V, Y::V[, Z::Vector{V}]) where {V}
    is_conditionally_independent(bn::BayesianNetwork, X::Vector{V}, Y::Vector{V}, Z::Vector{V}) where {V}

Test whether sets of variables X and Y are conditionally independent given set Z in a Bayesian Network using the Bayes Ball algorithm.

# Arguments
- `bn::BayesianNetwork`: The Bayesian Network structure
- `X::Vector{V}`: First set of variables to test for independence (must not be empty)
- `Y::Vector{V}`: Second set of variables to test for independence (must not be empty)
- `Z::Vector{V}`: Set of conditioning variables (can be empty)

The conditioning set used by the test is `Z` together with every variable that is
currently marked as observed in `bn`; observed variables are always added, whether or
not `Z` is supplied.

# Returns
- `true`: if X and Y are conditionally independent given Z (X ⊥ Y | Z)
- `false`: if X and Y are conditionally dependent given Z

If any variable of `X` or `Y` also appears in `Z`, the result is `true` without
inspecting the graph.

# Throws
- `ArgumentError`: if `X` or `Y` is empty

# Description
The Bayes Ball algorithm determines conditional independence by checking if there exists an active path between
variables in X and Y given Z. The algorithm follows these rules:
- In a chain (A → B → C): B blocks the path if conditioned
- In a fork (A ← B → C): B blocks the path if conditioned
- In a collider (A → B ← C): B opens the path if conditioned
"""
function is_conditionally_independent end
function is_conditionally_independent(
    bn::BayesianNetwork{V}, X::Vector{V}, Y::Vector{V}, Z::Vector{V}
) where {V}
    # Bayes-Ball reachability test for d-separation.
    #
    # Returns `true` when no active trail connects a variable of `X` to a variable of
    # `Y` given the conditioning set, and `false` as soon as one is found. Throws
    # `ArgumentError` when `X` or `Y` is empty.
    isempty(X) && throw(ArgumentError("X cannot be empty"))
    isempty(Y) && throw(ArgumentError("Y cannot be empty"))

    x_ids = Set([bn.names_to_ids[x] for x in X])
    y_ids = Set([bn.names_to_ids[y] for y in Y])
    z_ids = Set([bn.names_to_ids[z] for z in Z])

    # A query variable that is itself explicitly conditioned on is treated as
    # trivially independent.
    if !isempty(intersect(x_ids, z_ids)) || !isempty(intersect(y_ids, z_ids))
        return true
    end

    # Variables already observed in the network are part of the conditioning set too.
    for (id, is_obs) in enumerate(bn.is_observed)
        if is_obs
            push!(z_ids, id)
        end
    end

    # A node can be entered from a child or from a parent, and the two cases
    # propagate differently, so each direction is tracked separately.
    n_vertices = Graphs.nv(bn.graph)
    visited_from_child = falses(n_vertices)
    visited_from_parent = falses(n_vertices)

    # Queue entries are (node_id, arrived_from_child).
    queue = Tuple{Int,Bool}[]

    # Start at every X node as if the ball arrived from a child: an unconditioned
    # node entered that way passes the ball to both its parents and its children.
    for x_id in x_ids
        push!(queue, (x_id, true))
    end

    while !isempty(queue)
        current_id, from_child = popfirst!(queue)

        seen = from_child ? visited_from_child : visited_from_parent
        seen[current_id] && continue
        seen[current_id] = true

        is_conditioned = current_id in z_ids

        # Reaching an unconditioned Y node means an active trail exists.
        if !is_conditioned && current_id in y_ids
            return false
        end

        if from_child
            # Entered from a child: a conditioned node blocks the ball (chain and
            # fork); an unconditioned node passes it to parents and to children.
            if !is_conditioned
                for parent in Graphs.inneighbors(bn.graph, current_id)
                    push!(queue, (parent, true))
                end
                for child in Graphs.outneighbors(bn.graph, current_id)
                    push!(queue, (child, false))
                end
            end
        elseif is_conditioned
            # Entered from a parent at a conditioned node: the ball bounces back up to
            # all parents. This is what activates a collider, including one whose
            # conditioned node is only a descendant.
            for parent in Graphs.inneighbors(bn.graph, current_id)
                push!(queue, (parent, true))
            end
        else
            # Entered from a parent at an unconditioned node: the ball continues down.
            for child in Graphs.outneighbors(bn.graph, current_id)
                push!(queue, (child, false))
            end
        end
    end

    # No active trail from X reached Y.
    return true
end

function is_conditionally_independent(bn::BayesianNetwork{V}, X::V, Y::V) where {V}
    # Single-variable form without an explicit conditioning set. Delegates to the
    # vector method with an empty `Z`; that method adds the variables currently
    # observed in `bn` to the conditioning set, so the test is made given the
    # network's existing observations.
    return is_conditionally_independent(bn, [X], [Y], V[])
end

# Single variable version with Z
# Convenience method for testing one variable against another: wraps `X` and `Y` in
# one-element vectors and forwards to the vector method with the same `Z`.
function is_conditionally_independent(
bn::BayesianNetwork{V}, X::V, Y::V, Z::Vector{V}
) where {V}
# Returns whatever the vector method returns for ([X], [Y], Z).
return is_conditionally_independent(bn, [X], [Y], Z)
end
Loading
Loading