Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/T4AITensorCompat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ export random_mps, random_mpo # Random tensor train generation
export product # Official API name (match ITensorMPS)
export apply # Backwards-compatible alias

abstract type AbstractTTN end # Abstract type for tree tensor network

include("defaults.jl")
include("tensortrain.jl")
Expand Down
59 changes: 16 additions & 43 deletions src/contraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,50 +67,14 @@ result = contract(M1, M2) # Using default fit algorithm
result = contract(M1, M2; alg=Algorithm"densitymatrix"(), maxdim=100)
```
"""
function contract(M1::TensorTrain, M2::TensorTrain; alg=Algorithm"fit"(), cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), nsweeps::Int=default_nsweeps(), kwargs...)::TensorTrain
# Detect MPS-like vs MPO-like by counting physical indices per site
is_mps1 = begin
sites_per_tensor = siteinds(M1)
length(M1) > 0 && all(length(s) == 1 for s in sites_per_tensor)
end
is_mps2 = begin
sites_per_tensor = siteinds(M2)
length(M2) > 0 && all(length(s) == 1 for s in sites_per_tensor)
end

# Convert to ITensorMPS types based on detected type
M1_ = is_mps1 ? ITensorMPS.MPS(M1) : ITensorMPS.MPO(M1)
M2_ = is_mps2 ? ITensorMPS.MPS(M2) : ITensorMPS.MPO(M2)
function contract(M1::TensorTrain, M2::TensorTrain; alg::Union{String, Algorithm}=Algorithm"fit"(), cutoff::Real=default_cutoff(), maxdim::Int=default_maxdim(), nsweeps::Int=default_nsweeps(), kwargs...)::TensorTrain
# Convert both M1 and M2 to MPO format for T4A implementation
M1_ = ITensorMPS.MPO(M1)
M2_ = ITensorMPS.MPO(M2)

alg = Algorithm(alg)

# Handle MPO * MPS case
if !is_mps1 && is_mps2
# MPO * MPS: use ITensorMPS.contract
# Only pass nsweeps for fit algorithm, otherwise remove it from kwargs
if alg == Algorithm"fit"()
result = ITensorMPS.contract(M1_, M2_; alg=alg, cutoff=cutoff, maxdim=maxdim, nsweeps=nsweeps, kwargs...)
else
# Remove nsweeps from kwargs for non-fit algorithms
kwargs_dict = Dict(kwargs)
delete!(kwargs_dict, :nsweeps)
result = ITensorMPS.contract(M1_, M2_; alg=alg, cutoff=cutoff, maxdim=maxdim, kwargs_dict...)
end
return TensorTrain(result)
elseif is_mps1 && !is_mps2
# MPS * MPO: use ITensorMPS.contract (commutative)
# Only pass nsweeps for fit algorithm, otherwise remove it from kwargs
if alg == Algorithm"fit"()
result = ITensorMPS.contract(M2_, M1_; alg=alg, cutoff=cutoff, maxdim=maxdim, nsweeps=nsweeps, kwargs...)
else
# Remove nsweeps from kwargs for non-fit algorithms
kwargs_dict = Dict(kwargs)
delete!(kwargs_dict, :nsweeps)
result = ITensorMPS.contract(M2_, M1_; alg=alg, cutoff=cutoff, maxdim=maxdim, kwargs_dict...)
end
return TensorTrain(result)
else
# MPO * MPO: use T4AITensorCompat algorithms
# Use T4AITensorCompat algorithms
if alg == Algorithm"densitymatrix"()
return TensorTrain(ContractionImpl.contract_densitymatrix(M1_, M2_; cutoff, maxdim, kwargs...))
elseif alg == Algorithm"fit"()
Expand All @@ -121,7 +85,6 @@ function contract(M1::TensorTrain, M2::TensorTrain; alg=Algorithm"fit"(), cutoff
return TensorTrain(ITensorMPS.contract(M1_, M2_; alg=Algorithm"naive"(), cutoff, maxdim, kwargs...))
else
error("Unknown algorithm: $alg")
end
end
end

Expand Down Expand Up @@ -197,7 +160,7 @@ using `contract` internally and adjusting prime levels for compatibility.
- `nsweeps::Int`: Number of sweeps for variational algorithms. Defaults to `default_nsweeps()`.
- `kwargs...`: Additional keyword arguments passed to contract
"""
function product(A::TensorTrain, Ψ::TensorTrain; alg=Algorithm"fit"(), cutoff=default_cutoff(), maxdim=default_maxdim(), nsweeps=default_nsweeps(), kwargs...)
function product(A::TensorTrain, Ψ::TensorTrain; alg::Union{String, Algorithm}=Algorithm"fit"(), cutoff=default_cutoff(), maxdim=default_maxdim(), nsweeps=default_nsweeps(), kwargs...)
if :algorithm ∈ keys(kwargs)
error("keyword argument :algorithm is not allowed")
end
Expand All @@ -210,6 +173,16 @@ function product(A::TensorTrain, Ψ::TensorTrain; alg=Algorithm"fit"(), cutoff=d
@warn "cutoff is too small for densitymatrix algorithm. Use fit algorithm instead."
end

# Check that A is MPO-like (has 2 or more physical indices per site)
is_mpo_like_A = begin
sites_per_tensor = siteinds(A)
length(A) > 0 && all(length(s) >= 2 for s in sites_per_tensor)
end

if !is_mpo_like_A
error("First argument `A` must be MPO-like (have 2 or more physical indices per site), but got a tensor train with $(length(A) > 0 ? length(siteinds(A)[1]) : 0) physical index(es) per site")
end

# Detect MPS-like vs MPO-like by counting physical indices per site:
# MPS tensors have 1 physical index per site, MPO tensors have 2 per site.
# This is robust to boundary tensors having fewer link indices.
Expand Down
21 changes: 0 additions & 21 deletions src/ext/T4AITensorCompatChainRulesCoreExt.jl

This file was deleted.

Loading
Loading