Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/library/public.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ LastStage
### Functions

```@docs
compile_algo
@posterior_marginals
prod
sum
redu
Expand Down
17 changes: 7 additions & 10 deletions docs/src/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ Calculates the posterior marginal of each variable in the input graph. The
input graph should be defined in the [UAI model file format](@ref).

```@example main
algo = compile_algo("problems/asia/asia.uai")
eval(algo)
algo = @posterior_marginals("problems/asia/asia.uai")
obsvars, obsvals = Int64[], Int64[]
marginals = run_algo(obsvars, obsvals)
marginals = algo(obsvars, obsvals)
```

#### Example 2
Expand All @@ -27,12 +26,11 @@ format](@ref). The evidence variables and values should be given in the [UAI
evidence file format](@ref).

```@example main
algo = compile_algo(
algo = @posterior_marginals(
"problems/asia/asia.uai",
uai_evid_filepath = "problems/asia/asia.uai.evid")
eval(algo)
obsvars, obsvals = JunctionTrees.read_uai_evid_file("problems/asia/asia.uai.evid")
marginals = run_algo(obsvars, obsvals)
marginals = algo(obsvars, obsvals)
```

#### Example 3
Expand All @@ -42,13 +40,12 @@ junction tree (which is passed as an argument) is used. This junction tree
should be defined in the [PACE graph format](@ref).

```@example main
algo = compile_algo(
algo = @posterior_marginals(
"problems/asia/asia.uai",
uai_evid_filepath = "problems/asia/asia.uai.evid",
td_filepath = "problems/asia/asia.td")
eval(algo)
obsvars, obsvals = JunctionTrees.read_uai_evid_file("problems/asia/asia.uai.evid")
marginals = run_algo(obsvars, obsvals)
marginals = algo(obsvars, obsvals)
```

#### Example 4
Expand All @@ -57,7 +54,7 @@ Returns the expression of the junction tree algorithm up to the backward pass
stage.

```@example main
backward_pass_expr = compile_algo( "problems/asia/asia.uai", last_stage = BackwardPass)
backward_pass_expr = @posterior_marginals("problems/asia/asia.uai", last_stage = BackwardPass)
```

The stages supported are:
Expand Down
4 changes: 2 additions & 2 deletions src/JunctionTrees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Main module for `JunctionTrees.jl` -- a Julia implementation of the junction tre

One main function is exported from this module for public use:

- [`compile_algo`](@ref). Compiles and returns an expression that computes the posterior marginals of the model given evidence using the junction tree algorithm.
- [`@posterior_marginals`](@ref). Compiles and returns an expression that computes the posterior marginals of the model given evidence using the junction tree algorithm.

# Exports

Expand All @@ -16,7 +16,7 @@ using Combinatorics: combinations
using MacroTools: @capture, rmlines
using MLStyle: @match

export compile_algo, Factor, prod, sum, redu, norm, LastStage, ForwardPass,
export @posterior_marginals, Factor, prod, sum, redu, norm, LastStage, ForwardPass,
BackwardPass, JointMarginals, UnnormalizedMarginals, Marginals

import Base:
Expand Down
6 changes: 0 additions & 6 deletions src/graphical_transformation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,6 @@ Construct a tree decomposition graph based on `td_filepath`.

The `td_filepath` file format is defined in:
https://pacechallenge.org/2017/treewidth/.

# Example
```
td_filepath = "../problems/Promedus_26/Promedus_26.td"
td = compile_algo(td_filepath)
```
"""
function construct_td_graph(td_filepath::AbstractString)

Expand Down
50 changes: 28 additions & 22 deletions src/junction_tree_algorithm.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Enumerated type used to select up to which stage an expression of the junction
tree algorithm should be returned after calling [`compile_algo`](@ref).
tree algorithm should be returned after calling [`@posterior_marginals`](@ref).
"""
@enum LastStage begin
ForwardPass
Expand All @@ -9,11 +9,18 @@ tree algorithm should be returned after calling [`compile_algo`](@ref).
UnnormalizedMarginals
Marginals
end
@doc "When assigned to the keyword argument `last_stage` of [`compile_algo`](@ref), an expression up to and including the forward pass is returned." ForwardPass
@doc "When assigned to the keyword argument `last_stage` of [`compile_algo`](@ref), an expression up to and including the backward pass is returned." BackwardPass
@doc "When assigned to the keyword argument `last_stage` of [`compile_algo`](@ref), an expression that computes the cluster joint marginals is returned." JointMarginals
@doc "When assigned to the keyword argument `last_stage` of [`compile_algo`](@ref), an expression that computes the joint marginals is returned." UnnormalizedMarginals
@doc "When assigned to the keyword argument `last_stage` of [`compile_algo`](@ref), an expression that computes the posterior marginals is returned (default)." Marginals
@doc "When assigned to the keyword argument `last_stage` of [`@posterior_marginals`](@ref), an expression up to and including the forward pass is returned." ForwardPass
@doc "When assigned to the keyword argument `last_stage` of [`@posterior_marginals`](@ref), an expression up to and including the backward pass is returned." BackwardPass
@doc "When assigned to the keyword argument `last_stage` of [`@posterior_marginals`](@ref), an expression that computes the cluster joint marginals is returned." JointMarginals
@doc "When assigned to the keyword argument `last_stage` of [`@posterior_marginals`](@ref), an expression that computes the joint marginals is returned." UnnormalizedMarginals
@doc "When assigned to the keyword argument `last_stage` of [`@posterior_marginals`](@ref), an expression that computes the posterior marginals is returned (default)." Marginals

macro posterior_marginals(uai_filepath, kwargs...)
posterior_marginals(
Base.eval(__module__, uai_filepath);
map(kwarg -> Pair(first(kwarg.args), Base.eval(__module__, last(kwarg.args))), kwargs)...
)
end

"""
$(TYPEDSIGNATURES)
Expand All @@ -36,10 +43,9 @@ of all the variables in the model.
```
package_root_dir = pathof(JunctionTrees) |> dirname |> dirname
uai_filepath = joinpath(package_root_dir, "docs", "src", "problems", "paskin", "paskin.uai")
algo = compile_algo(uai_filepath)
eval(algo)
algo = @posterior_marginals(uai_filepath)
obsvars, obsvals = Int64[], Int64[]
marginals = run_algo(obsvars, obsvals)
marginals = algo(obsvars, obsvals)

# output

Expand All @@ -56,12 +62,11 @@ marginals = run_algo(obsvars, obsvals)
package_root_dir = pathof(JunctionTrees) |> dirname |> dirname
uai_filepath = joinpath(package_root_dir, "docs", "src", "problems", "paskin", "paskin.uai")
uai_evid_filepath = joinpath(package_root_dir, "docs", "src", "problems", "paskin", "paskin.uai.evid")
algo = compile_algo(
algo = @posterior_marginals(
uai_filepath,
uai_evid_filepath = uai_evid_filepath)
eval(algo)
obsvars, obsvals = JunctionTrees.read_uai_evid_file(uai_evid_filepath)
marginals = run_algo(obsvars, obsvals)
marginals = algo(obsvars, obsvals)

# output

Expand All @@ -74,16 +79,17 @@ marginals = run_algo(obsvars, obsvals)
Factor{Float64, 1}((6,), [0.6118571666785584, 0.3881428333214415])
```
"""
function compile_algo(uai_filepath::AbstractString;
uai_evid_filepath::AbstractString = "",
td_filepath::AbstractString = "",
apply_partial_evaluation::Bool = false,
last_stage::LastStage = Marginals,
smart_root_selection::Bool = true,
factor_eltype::DataType = Float64,
use_omeinsum::Bool = false,
correct_fp_overflows::Bool = false,
)
function posterior_marginals(
uai_filepath::AbstractString;
uai_evid_filepath::AbstractString="",
td_filepath::AbstractString="",
apply_partial_evaluation::Bool=false,
last_stage::LastStage=Marginals,
smart_root_selection::Bool=true,
factor_eltype::DataType=Float64,
use_omeinsum::Bool=false,
correct_fp_overflows::Bool=false,
)

# Read PGM
nvars, cards, _, factors = read_uai_file(uai_filepath, factor_eltype = factor_eltype)
Expand Down
5 changes: 2 additions & 3 deletions src/omeinsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ Speed up the sum-product in the algorithm using `OMEinsum`'s contraction routine
```
package_root_dir = pathof(JunctionTrees) |> dirname |> dirname
uai_filepath = joinpath(package_root_dir, "docs", "src", "problems", "paskin", "paskin.uai")
algo = compile_algo(uai_filepath, use_omeinsum = true);
eval(algo)
algo = @posterior_marginals(uai_filepath, use_omeinsum = true);
obsvars, obsvals = Int64[], Int64[]
marginals = run_algo(obsvars, obsvals)
marginals = algo(obsvars, obsvals)
```
"""
function boost_algo(algo::Expr; optimizer=GreedyMethod())
Expand Down
35 changes: 15 additions & 20 deletions test/junction_tree_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,12 @@ using JunctionTrees

@debug " Test: Default (Min-fill heuristic)"
@debug " Compiling algo..."
algo = compile_algo(
uai_filepath;
algo = @posterior_marginals(
uai_filepath,
uai_evid_filepath = uai_evid_filepath,
)
eval(algo)
@debug " Running algo..."
marginals = run_algo(obsvars, obsvals) |> x -> map(y -> y.vals, x)
marginals = algo(obsvars, obsvals) |> x -> map(y -> y.vals, x)
@test isapprox(marginals, reference_marginals, atol=0.03)

# ------------------------------------------------------------------------------
Expand All @@ -37,15 +36,14 @@ using JunctionTrees

@debug " Test: Using an existing junction tree"
@debug " Compiling algo..."
algo = compile_algo(
uai_filepath;
algo = @posterior_marginals(
uai_filepath,
uai_evid_filepath = uai_evid_filepath,
td_filepath = td_filepath,
correct_fp_overflows = true,
)
eval(algo)
@debug " Running algo..."
marginals = run_algo(obsvars, obsvals) |> x -> map(y -> y.vals, x)
marginals = algo(obsvars, obsvals) |> x -> map(y -> y.vals, x)
@test isapprox(marginals, reference_marginals, atol=0.03)

# ------------------------------------------------------------------------------
Expand All @@ -54,15 +52,14 @@ using JunctionTrees

@debug " Test: Float32 factor values"
@debug " Compiling algo..."
algo = compile_algo(
uai_filepath;
algo = @posterior_marginals(
uai_filepath,
uai_evid_filepath = uai_evid_filepath,
td_filepath = td_filepath,
factor_eltype = Float32,
)
eval(algo)
@debug " Running algo..."
marginals = run_algo(obsvars, obsvals) |> x -> map(y -> y.vals, x)
marginals = algo(obsvars, obsvals) |> x -> map(y -> y.vals, x)
@test isapprox(marginals, reference_marginals)

# ------------------------------------------------------------------------------
Expand All @@ -71,16 +68,15 @@ using JunctionTrees

@debug " Test: Partial evaluation"
@debug " Compiling algo..."
algo = compile_algo(
uai_filepath;
algo = @posterior_marginals(
uai_filepath,
uai_evid_filepath = uai_evid_filepath,
td_filepath = td_filepath,
apply_partial_evaluation = true,
correct_fp_overflows = true,
)
eval(algo)
@debug " Running algo..."
marginal_factors = run_algo(obsvars, obsvals)
marginal_factors = algo(obsvars, obsvals)
# Filter the observed variables from the obtained solution
marginal_factors_filtered = filter(x -> !(x.vars[1] in obsvars) , marginal_factors)
marginals = map(y -> y.vals, marginal_factors_filtered)
Expand All @@ -94,15 +90,14 @@ using JunctionTrees

@debug " Test: OMEinsum"
@debug " Compiling algo..."
algo = compile_algo(
uai_filepath;
algo = @posterior_marginals(
uai_filepath,
uai_evid_filepath = uai_evid_filepath,
td_filepath = td_filepath,
use_omeinsum = true,
)
eval(algo)
@debug " Running algo..."
marginals = run_algo(obsvars, obsvals) |> x -> map(y -> y.vals, x)
marginals = algo(obsvars, obsvals) |> x -> map(y -> y.vals, x)
@test isapprox(marginals, reference_marginals, atol=0.03)

end
Expand Down
14 changes: 6 additions & 8 deletions test/uai2014.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,15 @@ for benchmark in benchmarks

println(" Test type: default")
println(" Compiling algo...")
algo = compile_algo(
uai_filepath;
algo = @posterior_marginals(
uai_filepath,
uai_evid_filepath = uai_evid_filepath,
td_filepath = td_filepath,
factor_eltype = Float64,
correct_fp_overflows = true,
)
eval(algo)
println(" Running algo...")
marginals = run_algo(obsvars, obsvals) |> x -> map(y -> y.vals, x)
marginals = algo(obsvars, obsvals) |> x -> map(y -> y.vals, x)
@test isapprox(marginals, reference_marginals, atol = 0.01)

# ------------------------------------------------------------------------------
Expand All @@ -68,16 +67,15 @@ for benchmark in benchmarks

# println(" Test type: OMEinsum backend")
# println(" Compiling algo...")
# algo = compile_algo(
# uai_filepath;
# algo = @posterior_marginals(
# uai_filepath,
# uai_evid_filepath = uai_evid_filepath,
# td_filepath = td_filepath,
# factor_eltype = Float64,
# use_omeinsum = true,
# )
# eval(algo)
# println(" Running algo...")
# marginals = run_algo(obsvars, obsvals) |> x -> map(y -> y.vals, x)
# marginals = algo(obsvars, obsvals) |> x -> map(y -> y.vals, x)
# @test isapprox(marginals, reference_marginals, atol = 0.01)

end
Expand Down