Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Variable Elimination #238

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,16 @@ using Distributions

include("bayesnet.jl")

# Names made available to users of this module. `Factor`, `create_factor`,
# `multiply_factors`, `marginalize` and `variable_elimination` are defined in
# bayesnet.jl; the remaining names are presumably defined there as well (their
# definitions are not visible in this excerpt — confirm).
export BayesianNetwork,
Factor,
create_factor,
multiply_factors,
marginalize,
add_stochastic_vertex!,
add_deterministic_vertex!,
add_edge!,
condition,
decondition,
variable_elimination

end
184 changes: 184 additions & 0 deletions src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,187 @@ function is_conditionally_independent(
) where {V}
# TODO: Implement
end

using LinearAlgebra

"""
    Factor

Intermediate object manipulated by `create_factor`, `multiply_factors`,
`marginalize` and `variable_elimination`.

# Fields
- `variables::Vector{Symbol}`: names of the variables the factor is defined over.
- `distribution::Distribution`: distribution object attached to the factor.
- `parents::Vector{Symbol}`: names of the parent variables; `create_factor` fills
  this with the names of the node's in-neighbours in the network graph.
"""
# NOTE(review): `distribution` has an abstract field type; a parametric
# `Factor{D<:Distribution}` would be more type-stable, but would change the
# generated constructors — left as-is here.
struct Factor
variables::Vector{Symbol}
distribution::Distribution
parents::Vector{Symbol}
end

"""
    create_factor(bn::BayesianNetwork, node::Symbol)

Return the `Factor` associated with the stochastic node `node` of `bn`.

The factor is defined over `[node]`, carries the distribution stored for that node,
and lists the names of the node's in-neighbours in `bn.graph` as its parents.

Raises an error (via `error`) when `node` is a deterministic node.
"""
function create_factor(bn::BayesianNetwork, node::Symbol)
    id = bn.names_to_ids[node]
    bn.is_stochastic[id] || error("Cannot create factor for deterministic node")

    # The node's position within `stochastic_ids` is used to index `distributions`.
    slot = findfirst(==(id), bn.stochastic_ids)
    node_dist = bn.distributions[slot]

    parent_names = Symbol[bn.names[p] for p in inneighbors(bn.graph, id)]

    return Factor([node], node_dist, parent_names)
end

"""
    multiply_factors(f1::Factor, f2::Factor)

Multiply two factors and return the combined `Factor`.

The result is defined over the union of both factors' `variables` and of both
factors' `parents` (duplicates removed, first occurrence kept). The distributions
are combined as follows:

- `Normal` × `Normal`: the renormalised product of the two Gaussian densities, i.e.
  the precision-weighted combination
  `μ = (μ₁σ₂² + μ₂σ₁²) / (σ₁² + σ₂²)`, `σ² = σ₁²σ₂² / (σ₁² + σ₂²)`.
- `Categorical` × `Categorical`: element-wise product of the probability vectors,
  renormalised to sum to one.
- any other pair: not supported. A warning is logged and the placeholder
  `Normal(0, 1)` is returned.

# Throws
- `ArgumentError` if both normals are degenerate (zero variance) with different
  means, or if the categorical product has zero total mass.
- `DimensionMismatch` if the categorical probability vectors differ in length.
"""
function multiply_factors(f1::Factor, f2::Factor)
    vars = unique(vcat(f1.variables, f2.variables))
    parents = unique(vcat(f1.parents, f2.parents))
    dist = _multiply_distributions(f1.distribution, f2.distribution)
    return Factor(vars, dist, parents)
end

# Renormalised product of two Gaussian densities (not the sum of the variables).
function _multiply_distributions(d1::Normal, d2::Normal)
    μ1, μ2 = mean(d1), mean(d2)
    v1, v2 = var(d1), var(d2)
    total = v1 + v2

    if iszero(total)
        # Two point masses: the product is only non-zero when they coincide.
        μ1 == μ2 || throw(
            ArgumentError("cannot multiply degenerate normals with different means")
        )
        return Normal(μ1, zero(μ1))
    end

    # Written without dividing by v1 or v2 so that a single zero variance
    # (a point mass) is handled without producing NaN.
    μ = (μ1 * v2 + μ2 * v1) / total
    σ = sqrt(v1 * v2 / total)
    return Normal(μ, σ)
end

# Element-wise product of two categorical probability vectors, renormalised.
function _multiply_distributions(d1::Categorical, d2::Categorical)
    p1, p2 = d1.p, d2.p
    # Guard explicitly: broadcasting would silently stretch a length-1 vector.
    length(p1) == length(p2) || throw(
        DimensionMismatch(
            "categorical factors have $(length(p1)) and $(length(p2)) categories"
        ),
    )
    weights = p1 .* p2
    total = sum(weights)
    total > 0 || throw(
        ArgumentError("product of categorical factors has zero total probability")
    )
    return Categorical(weights ./ total)
end

# Unsupported pair: keep the historical placeholder but make the fallback visible.
function _multiply_distributions(d1::Distribution, d2::Distribution)
    @warn "multiply_factors: unsupported distribution pair, returning Normal(0, 1)" first =
        typeof(d1) second = typeof(d2)
    return Normal(0, 1)
end

"""
    marginalize(factor::Factor, var::Symbol)

Remove `var` from a factor's scope.

Returns a new `Factor` whose `variables` and `parents` no longer contain `var`.
The distribution object is carried over unchanged, whatever its type.
"""
function marginalize(factor::Factor, var::Symbol)
    # NOTE(review): no summation/integration over `var` is performed on the
    # distribution itself; only the bookkeeping lists are updated.
    remaining_vars = filter(!=(var), factor.variables)
    remaining_parents = filter(!=(var), factor.parents)
    return Factor(remaining_vars, factor.distribution, remaining_parents)
end
"""
    variable_elimination(bn::BayesianNetwork, query::Symbol, evidence::Dict{Symbol,Float64})

Run variable elimination on `bn` for `query` given `evidence` and return a scalar.

The procedure is:

1. create one factor per stochastic node;
2. replace the factor of every observed stochastic node by an evidence factor
   (`Normal(val, 0.1)` for normal nodes, a one-hot `Categorical` for categorical
   nodes, where `val` is the 1-based category index);
3. eliminate, in `bn.names` order, every node that is neither `query` nor observed,
   by multiplying all factors mentioning it and marginalising it out;
4. multiply the factors that remain.

Progress is reported through `@debug` messages, so nothing is printed unless debug
logging is enabled.

# Returns
- `1.0` if no factor remains;
- the density of the final distribution evaluated at its own mean, if it is `Normal`;
- otherwise the first entry of the final distribution's probability vector `p`.

# Throws
- `ArgumentError` if evidence for a categorical node is not an integer in `1:K`,
  with `K` the number of categories.
- `KeyError` if an evidence variable is not a node of `bn`.
"""
function variable_elimination(
    bn::BayesianNetwork{Symbol,Int,Any}, query::Symbol, evidence::Dict{Symbol,Float64}
)
    @debug "Starting variable elimination" query evidence

    factors = _initial_factors(bn)
    _apply_evidence!(factors, bn, evidence)

    # Everything except the query and the observed nodes gets eliminated.
    eliminate_vars = Symbol[
        node for node in bn.names if node != query && !haskey(evidence, node)
    ]
    @debug "Variables to eliminate" eliminate_vars

    for var in eliminate_vars
        _eliminate_variable!(factors, var)
    end

    return _query_value(factors)
end

# Build one factor per stochastic node, keyed by node name.
function _initial_factors(bn)
    factors = Dict{Symbol,Factor}()
    for node in bn.names
        if bn.is_stochastic[bn.names_to_ids[node]]
            factors[node] = create_factor(bn, node)
        end
    end
    return factors
end

# Overwrite the factors of observed stochastic nodes with evidence factors.
function _apply_evidence!(factors::Dict{Symbol,Factor}, bn, evidence)
    for (var, val) in evidence
        node_id = bn.names_to_ids[var]
        bn.is_stochastic[node_id] || continue

        dist_idx = findfirst(==(node_id), bn.stochastic_ids)
        observed = _evidence_distribution(bn.distributions[dist_idx], val)
        if observed === nothing
            # Previously dropped silently; the node keeps its original factor.
            @warn "variable_elimination: evidence ignored for unsupported distribution" var
        else
            factors[var] = Factor([var], observed, Symbol[])
        end
    end
    return factors
end

# Evidence on a normal node is modelled as a narrow Gaussian around the value.
_evidence_distribution(::Normal, val) = Normal(val, 0.1)

# Evidence on a categorical node is a one-hot vector at the observed category.
function _evidence_distribution(prior::Categorical, val)
    ncat = length(prior.p)
    (isinteger(val) && 1 <= val <= ncat) || throw(
        ArgumentError(
            "evidence for a categorical node must be an integer in 1:$ncat, got $val"
        ),
    )
    onehot = zeros(ncat)
    onehot[Int(val)] = 1.0
    return Categorical(onehot)
end

# Any other distribution type has no evidence encoding.
_evidence_distribution(::Any, val) = nothing

# Multiply every factor that mentions `var`, marginalise `var` out, store the result.
function _eliminate_variable!(factors::Dict{Symbol,Factor}, var::Symbol)
    @debug "Eliminating variable" var

    relevant_keys = Symbol[
        k for (k, f) in factors if var in f.variables || var in f.parents
    ]
    isempty(relevant_keys) && return factors

    combined = reduce(multiply_factors, Factor[factors[k] for k in relevant_keys])
    reduced = marginalize(combined, var)

    for k in relevant_keys
        delete!(factors, k)
    end

    # A factor left without variables is not stored back.
    if !isempty(reduced.variables)
        factors[first(reduced.variables)] = reduced
    end
    return factors
end

# Collapse the remaining factors into the scalar returned to the caller.
function _query_value(factors::Dict{Symbol,Factor})
    isempty(factors) && return 1.0
    combined = reduce(multiply_factors, collect(values(factors)))
    return _summary_value(combined.distribution)
end

# Continuous case: density evaluated at the distribution's own mean.
_summary_value(dist::Normal) = pdf(dist, mean(dist))

# Otherwise: first entry of the probability vector.
_summary_value(dist) = dist.p[1]

"""
    variable_elimination(bn::BayesianNetwork, query::Symbol, evidence::Dict{Symbol,<:Any})

Convenience method accepting evidence values of any type.

Every value is converted with `Float64(v)` and the call is forwarded to the
`Dict{Symbol,Float64}` method. The conversion is the same for continuous and
categorical nodes; categorical observations are category indices and are used
1-based by the `Float64` method.

# Throws
- Whatever `Float64(v)` throws for a value that cannot be converted.
"""
function variable_elimination(
    bn::BayesianNetwork{Symbol,Int,Any}, query::Symbol, evidence::Dict{Symbol,<:Any}
)
    # No per-node distribution lookup is needed: the conversion does not depend on
    # the node type, and looking it up failed for non-stochastic evidence nodes.
    evidence_float = Dict{Symbol,Float64}(k => Float64(v) for (k, v) in evidence)
    return variable_elimination(bn, query, evidence_float)
end
Loading
Loading