diff --git a/docs/src/library/public.md b/docs/src/library/public.md index eac0c69..02beead 100644 --- a/docs/src/library/public.md +++ b/docs/src/library/public.md @@ -50,7 +50,7 @@ LastStage ### Functions ```@docs -compile_algo +@posterior_marginals prod sum redu diff --git a/docs/src/usage.md b/docs/src/usage.md index 06eba09..1888928 100644 --- a/docs/src/usage.md +++ b/docs/src/usage.md @@ -13,10 +13,9 @@ Calculates the posterior marginal of each variable in the input graph. The input graph should be defined in the [UAI model file format](@ref). ```@example main -algo = compile_algo("problems/asia/asia.uai") -eval(algo) +algo = @posterior_marginals("problems/asia/asia.uai") obsvars, obsvals = Int64[], Int64[] -marginals = run_algo(obsvars, obsvals) +marginals = algo(obsvars, obsvals) ``` #### Example 2 @@ -27,12 +26,11 @@ format](@ref). The evidence variables and values should be given in the [UAI evidence file format](@ref). ```@example main -algo = compile_algo( +algo = @posterior_marginals( "problems/asia/asia.uai", uai_evid_filepath = "problems/asia/asia.uai.evid") -eval(algo) obsvars, obsvals = JunctionTrees.read_uai_evid_file("problems/asia/asia.uai.evid") -marginals = run_algo(obsvars, obsvals) +marginals = algo(obsvars, obsvals) ``` #### Example 3 @@ -42,13 +40,12 @@ junction tree (which is passed as an argument) is used. This junction tree should be defined in the [PACE graph format](@ref). ```@example main -algo = compile_algo( +algo = @posterior_marginals( "problems/asia/asia.uai", uai_evid_filepath = "problems/asia/asia.uai.evid", td_filepath = "problems/asia/asia.td") -eval(algo) obsvars, obsvals = JunctionTrees.read_uai_evid_file("problems/asia/asia.uai.evid") -marginals = run_algo(obsvars, obsvals) +marginals = algo(obsvars, obsvals) ``` #### Example 4 @@ -57,7 +54,7 @@ Returns the expression of the junction tree algorithm up to the backward pass stage. ```@example main -backward_pass_expr = compile_algo( "problems/asia/asia.uai", last_stage = BackwardPass) +backward_pass_expr = @posterior_marginals("problems/asia/asia.uai", last_stage = BackwardPass) ``` The stages supported are: diff --git a/src/JunctionTrees.jl b/src/JunctionTrees.jl index dabdfa2..2a029e6 100644 --- a/src/JunctionTrees.jl +++ b/src/JunctionTrees.jl @@ -3,7 +3,7 @@ Main module for `JunctionTrees.jl` -- a Julia implementation of the junction tre One main function is exported from this module for public use: -- [`compile_algo`](@ref). Compiles and returns an expression that computes the posterior marginals of the model given evidence using the junction tree algorithm. +- [`@posterior_marginals`](@ref). Compiles and returns an expression that computes the posterior marginals of the model given evidence using the junction tree algorithm. # Exports @@ -16,7 +16,7 @@ using Combinatorics: combinations using MacroTools: @capture, rmlines using MLStyle: @match -export compile_algo, Factor, prod, sum, redu, norm, LastStage, ForwardPass, +export @posterior_marginals, Factor, prod, sum, redu, norm, LastStage, ForwardPass, BackwardPass, JointMarginals, UnnormalizedMarginals, Marginals import Base: diff --git a/src/graphical_transformation.jl b/src/graphical_transformation.jl index d0c0446..94325c3 100644 --- a/src/graphical_transformation.jl +++ b/src/graphical_transformation.jl @@ -55,12 +55,6 @@ Construct a tree decomposition graph based on `td_filepath`. The `td_filepath` file format is defined in: https://pacechallenge.org/2017/treewidth/. - -# Example -``` -td_filepath = "../problems/Promedus_26/Promedus_26.td" -td = compile_algo(td_filepath) -``` """ function construct_td_graph(td_filepath::AbstractString) diff --git a/src/junction_tree_algorithm.jl b/src/junction_tree_algorithm.jl index 9fd313c..dd3e584 100644 --- a/src/junction_tree_algorithm.jl +++ b/src/junction_tree_algorithm.jl @@ -1,6 +1,6 @@ """ Enumerated type used to select up to which stage an expression of the junction -tree algorithm should be returned after calling [`compile_algo`](@ref). +tree algorithm should be returned after calling [`@posterior_marginals`](@ref). """ @enum LastStage begin ForwardPass @@ -9,11 +9,18 @@ tree algorithm should be returned after calling [`compile_algo`](@ref). UnnormalizedMarginals Marginals end -@doc "When assigned to the keyword argument `last_stage` of [`compile_algo`](@ref), an expression up to and including the forward pass is returned." ForwardPass -@doc "When assigned to the keyword argument `last_stage` of [`compile_algo`](@ref), an expression up to and including the backward pass is returned." BackwardPass -@doc "When assigned to the keyword argument `last_stage` of [`compile_algo`](@ref), an expression that computes the cluster joint marginals is returned." JointMarginals -@doc "When assigned to the keyword argument `last_stage` of [`compile_algo`](@ref), an expression that computes the joint marginals is returned." UnnormalizedMarginals -@doc "When assigned to the keyword argument `last_stage` of [`compile_algo`](@ref), an expression that computes the posterior marginals is returned (default)." Marginals +@doc "When assigned to the keyword argument `last_stage` of [`@posterior_marginals`](@ref), an expression up to and including the forward pass is returned." ForwardPass +@doc "When assigned to the keyword argument `last_stage` of [`@posterior_marginals`](@ref), an expression up to and including the backward pass is returned." BackwardPass +@doc "When assigned to the keyword argument `last_stage` of [`@posterior_marginals`](@ref), an expression that computes the cluster joint marginals is returned." JointMarginals +@doc "When assigned to the keyword argument `last_stage` of [`@posterior_marginals`](@ref), an expression that computes the joint marginals is returned." UnnormalizedMarginals +@doc "When assigned to the keyword argument `last_stage` of [`@posterior_marginals`](@ref), an expression that computes the posterior marginals is returned (default)." Marginals + +macro posterior_marginals(uai_filepath, kwargs...) + posterior_marginals( + Base.eval(__module__, uai_filepath); + map(kwarg -> Pair(first(kwarg.args), Base.eval(__module__, last(kwarg.args))), kwargs)... + ) +end """ $(TYPEDSIGNATURES) @@ -36,10 +43,9 @@ of all the variables in the model. ``` package_root_dir = pathof(JunctionTrees) |> dirname |> dirname uai_filepath = joinpath(package_root_dir, "docs", "src", "problems", "paskin", "paskin.uai") -algo = compile_algo(uai_filepath) -eval(algo) +algo = @posterior_marginals(uai_filepath) obsvars, obsvals = Int64[], Int64[] -marginals = run_algo(obsvars, obsvals) +marginals = algo(obsvars, obsvals) # output @@ -56,12 +62,11 @@ marginals = run_algo(obsvars, obsvals) package_root_dir = pathof(JunctionTrees) |> dirname |> dirname uai_filepath = joinpath(package_root_dir, "docs", "src", "problems", "paskin", "paskin.uai") uai_evid_filepath = joinpath(package_root_dir, "docs", "src", "problems", "paskin", "paskin.uai.evid") -algo = compile_algo( +algo = @posterior_marginals( uai_filepath, uai_evid_filepath = uai_evid_filepath) -eval(algo) obsvars, obsvals = JunctionTrees.read_uai_evid_file(uai_evid_filepath) -marginals = run_algo(obsvars, obsvals) +marginals = algo(obsvars, obsvals) # output @@ -74,16 +79,17 @@ marginals = run_algo(obsvars, obsvals) Factor{Float64, 1}((6,), [0.6118571666785584, 0.3881428333214415]) ``` """ -function compile_algo(uai_filepath::AbstractString; - uai_evid_filepath::AbstractString = "", - td_filepath::AbstractString = "", - apply_partial_evaluation::Bool = false, - last_stage::LastStage = Marginals, - smart_root_selection::Bool = true, - factor_eltype::DataType = Float64, - use_omeinsum::Bool = false, - correct_fp_overflows::Bool = false, - ) +function posterior_marginals( + uai_filepath::AbstractString; + uai_evid_filepath::AbstractString="", + td_filepath::AbstractString="", + apply_partial_evaluation::Bool=false, + last_stage::LastStage=Marginals, + smart_root_selection::Bool=true, + factor_eltype::DataType=Float64, + use_omeinsum::Bool=false, + correct_fp_overflows::Bool=false, +) # Read PGM nvars, cards, _, factors = read_uai_file(uai_filepath, factor_eltype = factor_eltype) diff --git a/src/omeinsum.jl b/src/omeinsum.jl index 9a65150..10adc27 100644 --- a/src/omeinsum.jl +++ b/src/omeinsum.jl @@ -8,10 +8,9 @@ Speed up the sum-product in the algorithm using `OMEinsum`'s contraction routine ``` package_root_dir = pathof(JunctionTrees) |> dirname |> dirname uai_filepath = joinpath(package_root_dir, "docs", "src", "problems", "paskin", "paskin.uai") -algo = compile_algo(uai_filepath, use_omeinsum = true); -eval(algo) +algo = @posterior_marginals(uai_filepath, use_omeinsum = true); obsvars, obsvals = Int64[], Int64[] -marginals = run_algo(obsvars, obsvals) +marginals = algo(obsvars, obsvals) ``` """ function boost_algo(algo::Expr; optimizer=GreedyMethod()) diff --git a/test/junction_tree_algorithm.jl b/test/junction_tree_algorithm.jl index 3358075..b32b4d5 100644 --- a/test/junction_tree_algorithm.jl +++ b/test/junction_tree_algorithm.jl @@ -22,13 +22,12 @@ using JunctionTrees @debug " Test: Default (Min-fill heuristic)" @debug " Compiling algo..." - algo = compile_algo( - uai_filepath; + algo = @posterior_marginals( + uai_filepath, uai_evid_filepath = uai_evid_filepath, ) - eval(algo) @debug " Running algo..." - marginals = run_algo(obsvars, obsvals) |> x -> map(y -> y.vals, x) + marginals = algo(obsvars, obsvals) |> x -> map(y -> y.vals, x) @test isapprox(marginals, reference_marginals, atol=0.03) # ------------------------------------------------------------------------------ @@ -37,15 +36,14 @@ using JunctionTrees @debug " Test: Using an existing junction tree" @debug " Compiling algo..." - algo = compile_algo( - uai_filepath; + algo = @posterior_marginals( + uai_filepath, uai_evid_filepath = uai_evid_filepath, td_filepath = td_filepath, correct_fp_overflows = true, ) - eval(algo) @debug " Running algo..." - marginals = run_algo(obsvars, obsvals) |> x -> map(y -> y.vals, x) + marginals = algo(obsvars, obsvals) |> x -> map(y -> y.vals, x) @test isapprox(marginals, reference_marginals, atol=0.03) # ------------------------------------------------------------------------------ @@ -54,15 +52,14 @@ using JunctionTrees @debug " Test: Float32 factor values" @debug " Compiling algo..." - algo = compile_algo( - uai_filepath; + algo = @posterior_marginals( + uai_filepath, uai_evid_filepath = uai_evid_filepath, td_filepath = td_filepath, factor_eltype = Float32, ) - eval(algo) @debug " Running algo..." - marginals = run_algo(obsvars, obsvals) |> x -> map(y -> y.vals, x) + marginals = algo(obsvars, obsvals) |> x -> map(y -> y.vals, x) @test isapprox(marginals, reference_marginals) # ------------------------------------------------------------------------------ @@ -71,16 +68,15 @@ using JunctionTrees @debug " Test: Partial evaluation" @debug " Compiling algo..." - algo = compile_algo( - uai_filepath; + algo = @posterior_marginals( + uai_filepath, uai_evid_filepath = uai_evid_filepath, td_filepath = td_filepath, apply_partial_evaluation = true, correct_fp_overflows = true, ) - eval(algo) @debug " Running algo..." - marginal_factors = run_algo(obsvars, obsvals) + marginal_factors = algo(obsvars, obsvals) # Filter the observed variables from the obtained solution marginal_factors_filtered = filter(x -> !(x.vars[1] in obsvars) , marginal_factors) marginals = map(y -> y.vals, marginal_factors_filtered) @@ -94,15 +90,14 @@ using JunctionTrees @debug " Test: OMEinsum" @debug " Compiling algo..." - algo = compile_algo( - uai_filepath; + algo = @posterior_marginals( + uai_filepath, uai_evid_filepath = uai_evid_filepath, td_filepath = td_filepath, use_omeinsum = true, ) - eval(algo) @debug " Running algo..." - marginals = run_algo(obsvars, obsvals) |> x -> map(y -> y.vals, x) + marginals = algo(obsvars, obsvals) |> x -> map(y -> y.vals, x) @test isapprox(marginals, reference_marginals, atol=0.03) end diff --git a/test/uai2014.jl b/test/uai2014.jl index 1082b24..3293865 100644 --- a/test/uai2014.jl +++ b/test/uai2014.jl @@ -50,16 +50,15 @@ for benchmark in benchmarks println(" Test type: default") println(" Compiling algo...") - algo = compile_algo( - uai_filepath; + algo = @posterior_marginals( + uai_filepath, uai_evid_filepath = uai_evid_filepath, td_filepath = td_filepath, factor_eltype = Float64, correct_fp_overflows = true, ) - eval(algo) println(" Running algo...") - marginals = run_algo(obsvars, obsvals) |> x -> map(y -> y.vals, x) + marginals = algo(obsvars, obsvals) |> x -> map(y -> y.vals, x) @test isapprox(marginals, reference_marginals, atol = 0.01) # ------------------------------------------------------------------------------ @@ -68,16 +67,15 @@ for benchmark in benchmarks # println(" Test type: OMEinsum backend") # println(" Compiling algo...") - # algo = compile_algo( - # uai_filepath; + # algo = @posterior_marginals( + # uai_filepath, # uai_evid_filepath = uai_evid_filepath, # td_filepath = td_filepath, # factor_eltype = Float64, # use_omeinsum = true, # ) - # eval(algo) # println(" Running algo...") - # marginals = run_algo(obsvars, obsvals) |> x -> map(y -> y.vals, x) + # marginals = algo(obsvars, obsvals) |> x -> map(y -> y.vals, x) # @test isapprox(marginals, reference_marginals, atol = 0.01) end