diff --git a/src/mul.jl b/src/mul.jl index 6288dd0..a6c75b5 100644 --- a/src/mul.jl +++ b/src/mul.jl @@ -131,7 +131,7 @@ end By default, elementwise multiplication will be performed. """ function automul(M1::MPS, M2::MPS; tag_row::String="", tag_shared::String="", - tag_col::String="", alg="naive", kwargs...) + tag_col::String="", alg="naive", cutoff=1e-30, kwargs...) if in(:maxbonddim, keys(kwargs)) error("Illegal keyward parameter: maxbonddim. Use maxdim instead!") end @@ -153,13 +153,13 @@ function automul(M1::MPS, M2::MPS; tag_row::String="", tag_shared::String="", M1_, M2_ = preprocess(matmul, M1_, M2_) M1_, M2_ = preprocess(ewmul, M1_, M2_) - M = FastMPOContractions.contract_mpo_mpo(M1_, M2_; alg=alg, kwargs...) + M = FastMPOContractions.contract_mpo_mpo(M1_, M2_; alg=alg, cutoff=cutoff, kwargs...) M = Quantics.postprocess(matmul, M) M = Quantics.postprocess(ewmul, M) if in(:maxdim, keys(kwargs)) - truncate!(M; maxdim=kwargs[:maxdim]) + truncate!(M; maxdim=kwargs[:maxdim], cutoff=cutoff) end return asMPS(M) diff --git a/test/mul_tests.jl b/test/mul_tests.jl index 294fc08..ad4a83f 100644 --- a/test/mul_tests.jl +++ b/test/mul_tests.jl @@ -169,13 +169,14 @@ end end @testset "PartitionedMPS" begin - @testset "batchedmatmul" for T in [Float64] + @testset "batchedmatmul" for T in [Float64, ComplexF64] """ C(x, z, k) = sum_y A(x, y, k) * B(y, z, k) """ nbit = 2 D = 2 cutoff = 1e-25 + maxdim = typemax(Int) sx = [Index(2, "Qubit,x=$n") for n in 1:nbit] sy = [Index(2, "Qubit,y=$n") for n in 1:nbit] sz = [Index(2, "Qubit,z=$n") for n in 1:nbit] @@ -213,10 +214,10 @@ end @test b ≈ MPS(b_) ab = Quantics.automul( - a_, b_; tag_row="x", tag_shared="y", tag_col="z", alg="fit", cutoff + a_, b_; tag_row="x", tag_shared="y", tag_col="z", alg="fit", cutoff, maxdim ) ab_ref = Quantics.automul( - a, b; tag_row="x", tag_shared="y", tag_col="z", alg="fit", cutoff + a, b; tag_row="x", tag_shared="y", tag_col="z", alg="fit", cutoff, maxdim ) @test MPS(ab)≈ab_ref rtol=10 * sqrt(cutoff)