diff --git a/src/fitalgorithm.jl b/src/fitalgorithm.jl index 09ea5c6..a5b6cd8 100644 --- a/src/fitalgorithm.jl +++ b/src/fitalgorithm.jl @@ -5,7 +5,7 @@ import Base: copy """ Contract M1 and M2, and return the result as an MPO. """ -function contract_fit(M1::MPO, M2::MPO; init = nothing, kwargs...)::MPO +function contract_fit(M1::MPO, M2::MPO; init = nothing, nsweeps=1, kwargs...)::MPO M2_ = MPS([M2[v] for v in eachindex(M2)]) if init === nothing init_MPO::MPO = ITensors.contract(M1, M2; alg = "zipup", kwargs...) @@ -13,7 +13,7 @@ function contract_fit(M1::MPO, M2::MPO; init = nothing, kwargs...)::MPO else init = MPS([init[v] for v in eachindex(M2)]) end - M12_ = contract_fit(M1, M2_; init_mps = init, kwargs...) + M12_ = contract_fit(M1, M2_; init_mps = init, nsweeps = nsweeps, kwargs...) M12 = MPO([M12_[v] for v in eachindex(M1)]) return M12 diff --git a/src/util.jl b/src/util.jl index 1a074ba..f064906 100644 --- a/src/util.jl +++ b/src/util.jl @@ -41,6 +41,8 @@ function _log_or_not_dot(y::MPO, A::MPO, x::MPO, loginner::Bool; kwargs...)::Num sim!(linkinds, ydag) check_hascommoninds(siteinds, A, y) O = ydag[1] * A[1] * x[1] + + log_inner_tot = 0.0 if loginner normO = norm(O) log_inner_tot = log(normO)