diff --git a/src/belief.jl b/src/belief.jl index eede996..eba4dd2 100644 --- a/src/belief.jl +++ b/src/belief.jl @@ -80,7 +80,7 @@ function _collect_message!(vectors_out::Vector, t::AbstractArray, vectors_in::Ve # TODO: speed up if needed! code = star_code(length(vectors_in)) cost, gradient = cost_and_gradient(code, (t, vectors_in...)) - for (o, g) in zip(vectors_out, gradient[2:end]) + for (o, g) in zip(vectors_out, conj.(gradient[2:end])) o .= g end return cost[] @@ -115,7 +115,7 @@ Run the belief propagation algorithm, and return the final state and the informa ### Keyword Arguments - `max_iter::Int=100`: the maximum number of iterations -- `tol::Float64=1e-6`: the tolerance for the convergence +- `tol::Float64=1e-6`: the tolerance for convergence; convergence is checked by the infidelity of messages in consecutive iterations. For complex numbers, converged messages may differ only by a phase factor. - `damping::Float64=0.2`: the damping factor for the message update, updated-message = damping * old-message + (1 - damping) * new-message """ function belief_propagate(bp::BeliefPropgation; kwargs...) 
@@ -133,7 +133,7 @@ function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::In collect_message!(bp, state; normalize = true) process_message!(state; normalize = true, damping = damping) # check convergence - if all(iv -> all(it -> isapprox(state.message_in[iv][it], pre_message_in[iv][it], atol = tol), 1:length(bp.v2t[iv])), 1:num_variables(bp)) + if all(iv -> all(it -> message_converged(state.message_in[iv][it], pre_message_in[iv][it], atol = tol), 1:length(bp.v2t[iv])), 1:num_variables(bp)) return BPInfo(true, i) end pre_message_in = deepcopy(state.message_in) @@ -141,6 +141,13 @@ function belief_propagate!(bp::BeliefPropgation, state::BPState{T}; max_iter::In return BPInfo(false, max_iter) end +# check if two messages are converged by fidelity (needed for complex numbers) +function message_converged(a, b; atol) + a_norm = norm(a) + b_norm = norm(b) + return isapprox(a_norm, b_norm, atol=atol) && isapprox(sqrt(abs(a' * b)), a_norm, atol=atol) +end + # if BP is exact and converged (e.g. tree like), the result should be the same as the tensor network contraction function contraction_results(state::BPState{T}) where {T} return [sum(reduce((x, y) -> x .* y, mi)) for mi in state.message_in] diff --git a/src/mar.jl b/src/mar.jl index 3e399b4..0eaa154 100644 --- a/src/mar.jl +++ b/src/mar.jl @@ -78,6 +78,7 @@ probabilities of the queried variables, represented by tensors. function marginals(tn::TensorNetworkModel; usecuda = false, rescale = true)::Dict{Vector{Int}} # sometimes, the cost can overflow, then we need to rescale the tensors during contraction. 
cost, grads = cost_and_gradient(tn.code, (adapt_tensors(tn; usecuda, rescale)...,)) + grads = conj.(grads) @debug "cost = $cost" ixs = OMEinsum.getixsv(tn.code) queryvars = ixs[tn.unity_tensors_idx] diff --git a/test/belief.jl b/test/belief.jl index 150c302..1d43a56 100644 --- a/test/belief.jl +++ b/test/belief.jl @@ -46,21 +46,21 @@ end @testset "belief propagation" begin n = 5 chi = 3 - mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi) + mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi) bp = BeliefPropgation(mps_uai) @test TensorInference.initial_state(bp) isa TensorInference.BPState - state, info = belief_propagate(bp) + state, info = belief_propagate(bp; max_iter=100, tol=1e-8) @test info.converged @test info.iterations < 20 mars = marginals(state) tnet = TensorNetworkModel(mps_uai) mars_tnet = marginals(tnet) for v in 1:TensorInference.num_variables(bp) - @test mars[[v]] ≈ mars_tnet[[v]] atol=1e-6 + @test mars[[v]] ≈ mars_tnet[[v]] atol=1e-4 end end -@testset "belief propagation on circle" begin +@testset "belief propagation on circle (Real)" begin n = 10 chi = 3 mps_uai = TensorInference.random_tensor_train_uai(Float64, n, chi; periodic=true) @@ -78,6 +78,25 @@ end end end + +@testset "belief propagation on circle (Complex)" begin + n = 10 + chi = 3 + mps_uai = TensorInference.random_tensor_train_uai(ComplexF64, n, chi; periodic=true) # FIXME: fail to converge + bp = BeliefPropgation(mps_uai) + @test TensorInference.initial_state(bp) isa TensorInference.BPState + state, info = belief_propagate(bp; max_iter=100, tol=1e-6) + @test info.converged + @test info.iterations < 100 + contraction_res = TensorInference.contraction_results(state) + tnet = TensorNetworkModel(mps_uai) + mars = marginals(state) + mars_tnet = marginals(tnet) + for v in 1:TensorInference.num_variables(bp) + @test TensorInference.message_converged(mars[[v]], mars_tnet[[v]]; atol=1e-4) + end +end + @testset "marginal uai2014" begin for problem in 
[problem_from_artifact("uai2014", "MAR", "Promedus", 14), problem_from_artifact("uai2014", "MAR", "ObjectDetection", 42)] optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)