Skip to content

Commit

Permalink
Implement a bases api and rework abstract types
Browse files Browse the repository at this point in the history
  • Loading branch information
akirakyle committed Nov 25, 2024
1 parent 3d63b38 commit eab995e
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 69 deletions.
58 changes: 29 additions & 29 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
@@ -1,54 +1,54 @@
"""
Abstract base class for `Bra` and `Ket` states.
Abstract type for `Bra` and `Ket` states.
The state vector class stores the coefficients of an abstract state
in respect to a certain basis. These coefficients are stored in the
`data` field and the basis is defined in the `basis`
field.
The state vector type stores an abstract state with respect to a certain
Hilbert space basis.
All deriving types must define the `fullbasis` function which
returns the state vector's underlying `Basis`.
"""
abstract type StateVector{B,T} end
abstract type AbstractKet{B,T} <: StateVector{B,T} end
abstract type AbstractBra{B,T} <: StateVector{B,T} end
abstract type StateVector{B<:Basis} end
abstract type AbstractKet{B} <: StateVector{B} end
abstract type AbstractBra{B} <: StateVector{B} end

"""
Abstract base class for all operators.
Abstract type for all operators which represent linear maps between two
Hilbert spaces with respect to a given basis in each space.
All deriving operator classes have to define the fields
`basis_l` and `basis_r` defining the left and right side bases.
All deriving operator types must define the `fullbasis` function which
returns the operator's underlying `OperatorBasis`.
For fast time evolution also at least the function
`mul!(result::Ket,op::AbstractOperator,x::Ket,alpha,beta)` should be
implemented. Many other generic multiplication functions can be defined in
terms of this function and are provided automatically.
See [TODO: reference operators.md in docs]
"""
abstract type AbstractOperator{BL,BR} end
abstract type AbstractOperator{B<:OperatorBasis} end

"""
Base class for all super operator classes.
Super operators are bijective mappings from operators given in one specific
basis to operators, possibly given in respect to another, different basis.
To embed super operators in an algebraic framework they are defined with a
left hand basis `basis_l` and a right hand basis `basis_r` where each of
them again consists of a left and right hand basis.
```math
A_{bl_1,bl_2} = S_{(bl_1,bl_2) ↔ (br_1,br_2)} B_{br_1,br_2}
\\\\
A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)}
Abstract type for all super-operators which represent linear maps between two
operator spaces with respect to a given basis for each space.
All deriving operator types must define the `fullbasis` function which
returns the operator's underlying `SuperOperatorBasis`.
See [TODO: reference superoperators.md in docs]
```
"""
abstract type AbstractSuperOperator{B1,B2} end
abstract type AbstractSuperOperator{B<:SuperOperatorBasis} end

function summary(stream::IO, x::AbstractOperator)
print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n")
if samebases(x)
b = fullbasis(x)
print(stream, "$(typeof(x).name.name)(dim=$(length(b.basis_l))x$(length(b.basis_r)))\n")
if samebases(b)
print(stream, " basis: ")
show(stream, basis(x))
show(stream, basis(b))
else
print(stream, " basis left: ")
show(stream, x.basis_l)
show(stream, b.basis_l)
print(stream, "\n basis right: ")
show(stream, x.basis_r)
show(stream, b.basis_r)
end
end

Expand Down
68 changes: 55 additions & 13 deletions src/bases.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
Abstract base class for all specialized bases.
Abstract type for all specialized bases.
The Basis class is meant to specify a basis of the Hilbert space of the
studied system. Besides basis specific information all subclasses must
implement a shape variable which indicates the dimension of the used
The `Basis` type is meant to specify a basis of the Hilbert space of the
studied system. Besides basis specific information, all concrete subtypes must
implement a shape field which indicates the dimension of the used
Hilbert space. For a spin-1/2 Hilbert space this would be the
vector `[2]`. A system composed of two spins would then have a
shape vector `[2 2]`.
Expand All @@ -13,6 +13,26 @@ class.
"""
abstract type Basis end

"""
Parametric composite type for all operator bases.
See [TODO: reference operators.md in docs]
"""
struct OperatorBasis{BL<:Basis,BR<:Basis}
left::BL
right::BR
end

"""
Parametric composite type for all superoperator bases.
See [TODO: reference superoperators.md in docs]
"""
struct SuperOperatorBasis{BL<:OperatorBasis,BR<:OperatorBasis}
left::BL
right::BR
end

"""
length(b::Basis)
Expand All @@ -25,11 +45,24 @@ Base.length(b::Basis) = prod(b.shape)
Return the basis of an object.
If it's ambiguous, e.g. if an operator has a different left and right basis,
an [`IncompatibleBases`](@ref) error is thrown.
If it's ambiguous, e.g. if an operator or superoperator has a different
left and right basis, an [`IncompatibleBases`](@ref) error is thrown.
"""
function basis end

basis(b::OperatorBasis) = (check_samebases(b); b.left)
basis(b::SuperOperatorBasis) = (check_samebases(b); b.left.left)

"""
fullbasis(a)
Returns B where B<:Basis when typeof(a)<:StateVector.
Returns B where B<:OperatorBasis when typeof(a)<:AbstractOperator.
Returns B where B<:SuperOperatorBasis for typeof(a)<:AbstractSuperOperator.
"""
function fullbasis end


"""
GenericBasis(N)
Expand Down Expand Up @@ -80,13 +113,11 @@ contains another CompositeBasis.
tensor(b1::Basis, b2::Basis) = CompositeBasis([length(b1); length(b2)], (b1, b2))
tensor(b1::CompositeBasis, b2::CompositeBasis) = CompositeBasis([b1.shape; b2.shape], (b1.bases..., b2.bases...))
function tensor(b1::CompositeBasis, b2::Basis)
N = length(b1.bases)
shape = vcat(b1.shape, length(b2))
bases = (b1.bases..., b2)
CompositeBasis(shape, bases)
end
function tensor(b1::Basis, b2::CompositeBasis)
N = length(b2.bases)
shape = vcat(length(b1), b2.shape)
bases = (b1, b2.bases...)
CompositeBasis(shape, bases)
Expand Down Expand Up @@ -160,24 +191,35 @@ macro samebases(ex)
end

"""
samebases(a)
samebases(a, b)
Test if two objects have the same bases.
Test if one object has the same left and right bases or
if two objects have the same bases
"""
samebases(b1::Basis, b2::Basis) = b1==b2
samebases(b1::Tuple{Basis, Basis}, b2::Tuple{Basis, Basis}) = b1==b2 # for checking superoperators
samebases(b1::Basis, b2::Basis) = (b1 == b2)
samebases(b::OperatorBasis) = (b.left == b.right)
samebases(b1::OperatorBasis, b2::OperatorBasis) = ((b1.left == b2.left) && (b1.right == b2.right))
samebases(b::SuperOperatorBasis) = samebases(b.left, b.right)
samebases(b1::SuperOperatorBasis, b2::SuperOperatorBasis) = (samebases(b1.left, b2.left) && samebases(b1.right, b2.right))

"""
check_samebases(a)
check_samebases(a, b)
Throw an [`IncompatibleBases`](@ref) error if the objects don't have
the same bases.
Throw an [`IncompatibleBases`](@ref) error if the two objects don't have
the same bases or the one object doesn't have the same left and right bases.
"""
function check_samebases(b1, b2)
if BASES_CHECK[] && !samebases(b1, b2)
throw(IncompatibleBases())
end
end
function check_samebases(b)
if BASES_CHECK[] && !samebases(b)
throw(IncompatibleBases())
end
end


"""
Expand Down
12 changes: 6 additions & 6 deletions src/expect_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,33 @@
If an `index` is given, it assumes that `op` is defined in the subsystem specified by this number.
"""
function expect(indices, op::AbstractOperator{B1,B2}, state::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis}
function expect(indices, op::AbstractOperator{OperatorBasis{B1,B2}}, state::AbstractOperator{OperatorBasis{B3,B3}}) where {B1,B2,B3<:CompositeBasis}
N = length(state.basis_l.shape)
indices_ = complement(N, indices)
expect(op, ptrace(state, indices_))
end

expect(index::Integer, op::AbstractOperator{B1,B2}, state::AbstractOperator{B3,B3}) where {B1,B2,B3<:CompositeBasis} = expect([index], op, state)
expect(index::Integer, op::AbstractOperator{OperatorBasis{B1,B2}}, state::AbstractOperator{OperatorBasis{B3,B3}}) where {B1,B2,B3<:CompositeBasis} = expect([index], op, state)
expect(op::AbstractOperator, states::Vector) = [expect(op, state) for state=states]
expect(indices, op::AbstractOperator, states::Vector) = [expect(indices, op, state) for state=states]

expect(op::AbstractOperator{B1,B2}, state::AbstractOperator{B2,B2}) where {B1,B2} = tr(op*state)
expect(op::AbstractOperator{OperatorBasis{B1,B2}}, state::AbstractOperator{OperatorBasis{B2,B2}}) where {B1,B2} = tr(op*state)

"""
variance(index, op, state)
If an `index` is given, it assumes that `op` is defined in the subsystem specified by this number
"""
function variance(indices, op::AbstractOperator{B,B}, state::AbstractOperator{BC,BC}) where {B,BC<:CompositeBasis}
function variance(indices, op::AbstractOperator{OperatorBasis{B,B}}, state::AbstractOperator{OperatorBasis{BC,BC}}) where {B,BC<:CompositeBasis}
N = length(state.basis_l.shape)
indices_ = complement(N, indices)
variance(op, ptrace(state, indices_))
end

variance(index::Integer, op::AbstractOperator{B,B}, state::AbstractOperator{BC,BC}) where {B,BC<:CompositeBasis} = variance([index], op, state)
variance(index::Integer, op::AbstractOperator{OperatorBasis{B,B}}, state::AbstractOperator{OperatorBasis{BC,BC}}) where {B,BC<:CompositeBasis} = variance([index], op, state)
variance(op::AbstractOperator, states::Vector) = [variance(op, state) for state=states]
variance(indices, op::AbstractOperator, states::Vector) = [variance(indices, op, state) for state=states]

function variance(op::AbstractOperator{B,B}, state::AbstractOperator{B,B}) where B
function variance(op::AbstractOperator{OperatorBasis{B,B}}, state::AbstractOperator{OperatorBasis{B,B}}) where B
expect(op*op, state) - expect(op, state)^2
end
8 changes: 5 additions & 3 deletions src/identityoperator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@ type which has to a subtype of [`AbstractOperator`](@ref) as well as the number
to be used in the identity matrix.
"""
identityoperator(::Type{T}, ::Type{S}, b1::Basis, b2::Basis) where {T<:AbstractOperator,S} = throw(ArgumentError("Identity operator not defined for operator type $T."))
identityoperator(::Type{T}, ::Type{S}, b::OperatorBasis) where {T<:AbstractOperator,S} = identityoperator(T,S,b.left,b.right)
identityoperator(::Type{T}, ::Type{S}, b::Basis) where {T<:AbstractOperator,S} = identityoperator(T,S,b,b)
identityoperator(::Type{T}, bases::Basis...) where T<:AbstractOperator = identityoperator(T,eltype(T),bases...)
identityoperator(::Type{T}, b::OperatorBasis) where {T<:AbstractOperator} = identityoperator(T,eltype(T),b)
identityoperator(::Type{T}, bases::Basis...) where {T<:AbstractOperator} = identityoperator(T,eltype(T),bases...)
identityoperator(b::Basis) = identityoperator(ComplexF64,b)
identityoperator(op::T) where {T<:AbstractOperator} = identityoperator(T, op.basis_l, op.basis_r)
identityoperator(op::T) where {T<:AbstractOperator} = identityoperator(T, fullbasis(op))

# Catch case where eltype cannot be inferred from type; this is a bit hacky
identityoperator(::Type{T}, ::Type{Any}, b1::Basis, b2::Basis) where T<:AbstractOperator = identityoperator(T, ComplexF64, b1, b2)

identityoperator(b1::Basis, b2::Basis) = identityoperator(ComplexF64, b1, b2)

"""Prepare the identity superoperator over a given space."""
function identitysuperoperator end
function identitysuperoperator end
24 changes: 12 additions & 12 deletions src/julia_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an op
# States
##

-(a::T) where {T<:StateVector} = T(a.basis, -a.data)
-(a::T) where {T<:StateVector} = T(a.basis, -a.data) # FIXME
*(a::StateVector, b::Number) = b*a
copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data))
length(a::StateVector) = length(a.basis)::Int
basis(a::StateVector) = a.basis
copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) # FIXME
length(a::StateVector) = length(basis(a))::Int
basis(a::StateVector) = fullbasis(a)
directsum(x::StateVector...) = reduce(directsum, x)

# Array-like functions
Base.size(x::StateVector) = size(x.data)
@inline Base.axes(x::StateVector) = axes(x.data)
Base.size(x::StateVector) = size(x.data) # FIXME
@inline Base.axes(x::StateVector) = axes(x.data) #FIXME
Base.ndims(x::StateVector) = 1
Base.ndims(::Type{<:StateVector}) = 1
Base.eltype(x::StateVector) = eltype(x.data)
Base.eltype(x::StateVector) = eltype(x.data) # FIXME

# Broadcasting
Base.broadcastable(x::StateVector) = x
Expand All @@ -32,9 +32,9 @@ Base.adjoint(a::StateVector) = dagger(a)
# Operators
##

length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int
basis(a::AbstractOperator) = (check_samebases(a); a.basis_l)
basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1])
length(a::AbstractOperator) = (b=fullbasis(a); length(b.basis_l)::Int*length(b.basis_r)::Int)
basis(a::AbstractOperator) = (b=fullbasis(a); check_samebases(b); b.left)
basis(a::AbstractSuperOperator) = (b=fullbasis(a); check_samebases(b); b.left.left)

# Ensure scalar broadcasting
Base.broadcastable(x::AbstractOperator) = Ref(x)
Expand All @@ -60,11 +60,11 @@ Operator exponential.
"""
exp(op::AbstractOperator) = throw(ArgumentError("exp() is not defined for this type of operator: $(typeof(op)).\nTry to convert to dense operator first with dense()."))

Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r))
Base.size(op::AbstractOperator) = (b=fullbasis(op); (length(b.left),length(b.right)))
function Base.size(op::AbstractOperator, i::Int)
i < 1 && throw(ErrorException("dimension index is < 1"))
i > 2 && return 1
i==1 ? length(op.basis_l) : length(op.basis_r)
i==1 ? length(fullbasis(op).left) : length(fullbasis(op).right)
end

Base.adjoint(a::AbstractOperator) = dagger(a)
Expand Down
4 changes: 2 additions & 2 deletions src/julia_linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ tr(x::AbstractOperator) = arithmetic_unary_error("Trace", x)
Norm of the given bra or ket state.
"""
norm(x::StateVector) = norm(x.data)
norm(x::StateVector) = norm(x.data) # FIXME

"""
normalize(x::StateVector)
Expand All @@ -31,7 +31,7 @@ normalize(x::StateVector) = x/norm(x)
In-place normalization of the given bra or ket so that `norm(x)` is one.
"""
normalize!(x::StateVector) = (normalize!(x.data); x)
normalize!(x::StateVector) = (normalize!(x.data); x) # FIXME

"""
normalize(op)
Expand Down
14 changes: 10 additions & 4 deletions src/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool
samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool
check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r)
multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l)
samebases(a::AbstractOperator) = samebases(fullbasis(a))::Bool
samebases(a::AbstractSuperOperator) = samebases(fullbasis(a))::Bool
samebases(a::AbstractOperator, b::AbstractOperator) = samebases(fullbasis(a), fullbasis(b))::Bool
samebases(a::AbstractSuperOperator, b::AbstractSuperOperator) = samebases(fullbasis(a), fullbasis(b))::Bool
check_samebases(a::AbstractOperator) = check_samebases(fullbasis(a))::Bool
check_samebases(a::AbstractSuperOperator) = check_samebases(fullbasis(a))::Bool
check_samebases(a::AbstractOperator, b::AbstractOperator) = check_samebases(fullbasis(a), fullbasis(b))::Bool
check_samebases(a::AbstractSuperOperator, b::AbstractSuperOperator) = check_samebases(fullbasis(a), fullbasis(b))::Bool

multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(fullbasis(a).right, fullbasis(b).left)
dagger(a::AbstractOperator) = arithmetic_unary_error("Hermitian conjugate", a)
transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a)
directsum(a::AbstractOperator...) = reduce(directsum, a)
Expand Down

0 comments on commit eab995e

Please sign in to comment.