Skip to content

Commit

Permalink
modified bayes net definition
Browse files Browse the repository at this point in the history
  • Loading branch information
naseweisssss committed Nov 27, 2024
1 parent 61f2674 commit 6237304
Show file tree
Hide file tree
Showing 3 changed files with 359 additions and 398 deletions.
86 changes: 31 additions & 55 deletions src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ struct BayesianNetwork{V,T,F}
"values of each variable in the network"
values::Dict{V,Any} # TODO: make it a NamedTuple for better performance in the future
"distributions of the stochastic variables"
distributions::Vector{Distribution}
# A distribution can be either:
# - A fixed distribution (like Uniform(0,1))
# - A function that takes parent values and returns a distribution
distributions::Vector{Union{Distribution,Function}}
"deterministic functions of the deterministic variables"
deterministic_functions::Vector{F}
"ids of the stochastic variables"
Expand Down Expand Up @@ -115,16 +118,17 @@ Adds a stochastic vertex with name `name` and distribution `dist` to the Bayesia
if successful, 0 otherwise.
"""
function add_stochastic_vertex!(
bn::BayesianNetwork{V,T}, name::V, dist::Distribution, is_observed::Bool
bn::BayesianNetwork{V,T},
name::V,
dist::Union{Distribution,Function},
is_observed::Bool
)::T where {V,T}
Graphs.add_vertex!(bn.graph) || return 0
id = nv(bn.graph)
push!(bn.distributions, dist)
push!(bn.is_stochastic, true)
push!(bn.is_observed, is_observed)
push!(bn.names, name)
bn.names_to_ids[name] = id
push!(bn.stochastic_ids, id)
return id
end

Expand Down Expand Up @@ -366,65 +370,37 @@ function eliminate_variables(
query_id::Int,
assignments::Dict{V,Any}
) where {V}
# Base case: reached the query variable
# Base case: reached query variable
if isempty(ordered_vertices) || ordered_vertices[1] == query_id
dist_idx = findfirst(id -> id == query_id, bn.stochastic_ids)
return bn.distributions[dist_idx]
end

current_id = ordered_vertices[1]
remaining_vertices = ordered_vertices[2:end]
current_name = bn.names[current_id]

# If the current node is observed, use its value
if bn.is_observed[current_id]
assignments[current_name] = bn.values[current_name]
return eliminate_variables(bn, remaining_vertices, query_id, assignments)
end
# For current variable, create mixture over its values
components = Distribution[]
weights = Float64[]

# Handle stochastic nodes
if bn.is_stochastic[current_id]
dist_idx = findfirst(id -> id == current_id, bn.stochastic_ids)
current_dist = bn.distributions[dist_idx]
# Try both values (0 and 1) # TODO: generalize for other values
for value in [0, 1]
new_assignments = copy(assignments)
new_assignments[bn.names[current_id]] = value

if is_discrete_distribution(current_dist)
# For discrete nodes, create mixture of distributions
support_values = get_support(current_dist)
components = Distribution[]
weights = Float64[]

for value in support_values
# Create new assignment with current value
new_assignments = copy(assignments)
new_assignments[current_name] = value

# Recursive call
component = eliminate_variables(bn, remaining_vertices, query_id, new_assignments)
push!(components, component)
push!(weights, pdf(current_dist, value))
end

# Normalize weights
weights ./= sum(weights)

# Return mixture distribution with explicit variate form
if all(c isa DiscreteUnivariateDistribution for c in components)
return MixtureModel{Discrete, Univariate}(components, weights)
elseif all(c isa ContinuousUnivariateDistribution for c in components)
return MixtureModel{Continuous, Univariate}(components, weights)
else
error("Mixed discrete and continuous distributions not supported")
end
else
# For continuous nodes, integrate (not implemented yet)
error("Continuous variable elimination not implemented yet")
end
else
# Handle deterministic nodes
func_idx = findfirst(id -> id == current_id, bn.deterministic_ids)
parent_ids = Graphs.inneighbors(bn.graph, current_id)
parent_values = [assignments[bn.names[pid]] for pid in parent_ids]
assignments[current_name] = bn.deterministic_functions[func_idx](parent_values...)
return eliminate_variables(bn, remaining_vertices, query_id, assignments)
# Get distribution for remaining variables
component = eliminate_variables(bn, remaining_vertices, query_id, new_assignments)
println("Components so far: ", components)
println("Current component: ", component)
push!(components, component)

# Get weight from current node's distribution
dist_idx = findfirst(id -> id == current_id, bn.stochastic_ids)
push!(weights, pdf(bn.distributions[dist_idx], value))
end
end

# Normalize weights
weights ./= sum(weights)

return MixtureModel(components, weights)
end
Loading

0 comments on commit 6237304

Please sign in to comment.