Skip to content

Commit

Permalink
keep consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
naseweisssss committed Nov 18, 2024
1 parent 730bb36 commit 300209a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 330 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,16 @@ using Distributions

include("bayesnet.jl")

export BayesianNetwork,
Factor,
create_factor,
multiply_factors,
marginalize,
add_stochastic_vertex!,
add_deterministic_vertex!,
add_edge!,
condition,
decondition,
variable_elimination

end
109 changes: 8 additions & 101 deletions src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
A structure representing a Bayesian Network.
"""
struct BayesianNetwork{V,T,F}
graph::SimpleDiGraph{T}
graph::SimpleGraph{T}
"names of the variables in the network"
names::Vector{V}
"mapping from variable names to ids"
Expand All @@ -25,7 +25,7 @@ end

function BayesianNetwork{V}() where {V}
return BayesianNetwork(
SimpleDiGraph{Int}(), # by default, vertex ids are integers
SimpleGraph{Int}(), # by default, vertex ids are integers
V[],
Dict{V,Int}(),
Dict{V,Any}(),
Expand Down Expand Up @@ -164,36 +164,13 @@ Perform ancestral sampling on a Bayesian network to generate one sample from the
Ancestral sampling works by:
1. Finding a topological ordering of the nodes
2. Sampling from each node in order, using the already-sampled parent values for conditional distributions
### Return Value
The function returns a `Dict{V, Any}` where:
- Each key is a variable name (of type `V`) in the Bayesian Network.
- Each value is the sampled value for that variable, which can be of any type (`Any`).
This dictionary represents a single sample from the joint distribution of the Bayesian Network, capturing the dependencies and conditional relationships defined in the network structure.
"""
function ancestral_sampling(bn::BayesianNetwork{V}) where {V}
ordered_vertices = Graphs.topological_sort_by_dfs(bn.graph)
ordered_vertices = Graphs.topological_sort(bn.graph)

samples = Dict{V,Any}()

for vertex_id in ordered_vertices
vertex_name = bn.names[vertex_id]
if bn.is_observed[vertex_id]
samples[vertex_name] = bn.values[vertex_name]
continue
end
if bn.is_stochastic[vertex_id]
dist_idx = findfirst(id -> id == vertex_id, bn.stochastic_ids)
samples[vertex_name] = rand(bn.distributions[dist_idx])
else
# deterministic node
parent_ids = Graphs.inneighbors(bn.graph, vertex_id)
parent_values = [samples[bn.names[pid]] for pid in parent_ids]
func_idx = findfirst(id -> id == vertex_id, bn.deterministic_ids)
samples[vertex_name] = bn.deterministic_functions[func_idx](parent_values...)
end
end
# TODO: Implement sampling logic

return samples
end
Expand All @@ -207,82 +184,13 @@ If Z is provided, the conditioning information in `bn` will be ignored.
function is_conditionally_independent end

function is_conditionally_independent(bn::BayesianNetwork{V}, X::V, Y::V) where {V}
# Use currently observed variables as Z
Z = V[v for (v, is_obs) in zip(bn.names, bn.is_observed) if is_obs]
return is_conditionally_independent(bn, X, Y, Z)
# TODO: Implement
end

function is_conditionally_independent(
bn::BayesianNetwork{V}, X::V, Y::V, Z::Vector{V}
) where {V}
println("debugging: X: $X, Y: $Y, Z: $Z")
if X in Z || Y in Z
return true
end

# Get vertex IDs
x_id = bn.names_to_ids[X]
y_id = bn.names_to_ids[Y]
z_ids = Set([bn.names_to_ids[z] for z in Z])

# Track visited nodes and their states
n_vertices = nv(bn.graph)
visited = falses(n_vertices)

# Queue entries are (node_id, from_parent)
queue = Tuple{Int,Bool}[]

# Start from X
push!(queue, (x_id, true)) # As if coming from parent
push!(queue, (x_id, false)) # As if coming from child

while !isempty(queue)
current_id, from_parent = popfirst!(queue)

if visited[current_id]
continue
end
visited[current_id] = true

# If we reached Y, path is active
if current_id == y_id
return false
end

is_conditioned = current_id in z_ids
parents = inneighbors(bn.graph, current_id)
children = outneighbors(bn.graph, current_id)

# Case 1: Node is not conditioned
if !is_conditioned
# Can go to children if coming from parent or at start node
if from_parent || current_id == x_id
for child in children
push!(queue, (child, true))
end
end

# Can go to parents if coming from child or at start node
if !from_parent || current_id == x_id
for parent in parents
push!(queue, (parent, false))
end
end
end

# Case 2: Node is conditioned or has conditioned descendants
if is_conditioned
# If this is a collider or descendant of collider
if length(parents) > 1 || !isempty(children)
# Can go to parents regardless of direction
for parent in parents
push!(queue, (parent, false))
end
end
end
end

return true
# TODO: Implement
end

using LinearAlgebra
Expand Down Expand Up @@ -350,7 +258,6 @@ function marginalize(factor::Factor, var::Symbol)

return Factor(new_vars, factor.distribution, new_parents)
end

"""
variable_elimination(bn::BayesianNetwork, query::Symbol, evidence::Dict{Symbol,Any})
Expand Down Expand Up @@ -456,4 +363,4 @@ function variable_elimination(
# Convert evidence to Dict{Symbol,Float64}
evidence_float = Dict{Symbol,Float64}(k => Float64(v) for (k, v) in evidence)
return variable_elimination(bn, query, evidence_float)
end
end
Loading

0 comments on commit 300209a

Please sign in to comment.