Skip to content

Commit

Permalink
edits
Browse files Browse the repository at this point in the history
  • Loading branch information
aarontrowbridge committed Jun 23, 2023
1 parent 88aeaf7 commit 980a4fa
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions src/problem_templates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,16 @@ function UnitarySmoothPulseProblem(
constraints::Vector{<:AbstractConstraint}=AbstractConstraint[],
timesteps_all_equal::Bool=true,
verbose=false,
U_init=Union{AbstractMatrix{<:Number}, Nothing}=nothing,
)
U_goal = Matrix{ComplexF64}(U_goal)

if isnothing(U_init)
Ũ⃗_init = operator_to_iso_vec(1.0I(size(U_goal, 1)))
else
Ũ⃗_init = operator_to_iso_vec(U_init)
end

n_drives = length(system.G_drives)

if !isnothing(init_trajectory)
Expand All @@ -62,17 +69,22 @@ function UnitarySmoothPulseProblem(

if isnothing(a_guess)
# TODO: add warning in case U_goal is not unitary
Ũ⃗ = unitary_geodesic(U_goal, T)
a_dists = [Uniform(-a_bounds[i], a_bounds[i]) for i = 1:n_drives]
a = hcat([
zeros(n_drives),
vcat([rand(a_dists[i], 1, T - 2) for i = 1:n_drives]...),
zeros(n_drives)
]...)
try
Ũ⃗ = unitary_geodesic(U_goal, T)
catch e
@warn "Could not find geodesic. Using random initial trajectory."
Ũ⃗ = unitary_rollout(Ũ⃗_init, a, Δt, system)
end
da = randn(n_drives, T) * drive_derivative_σ
dda = randn(n_drives, T) * drive_derivative_σ
else
Ũ⃗ = unitary_rollout(a_guess, Δt, system)
Ũ⃗ = unitary_rollout(Ũ⃗_init, a_guess, Δt, system)
Δt = vec(Δt)
a = a_guess
da = derivative(a, Δt)
Expand All @@ -93,8 +105,14 @@ function UnitarySmoothPulseProblem(
Δt = (Δt_min, Δt_max),
)

if isnothing(U_init)
Ũ⃗_init = operator_to_iso_vec(1.0I(size(U_goal, 1)))
else
Ũ⃗_init = operator_to_iso_vec(U_init)
end

initial = (
Ũ⃗ = operator_to_iso_vec(1.0I(size(U_goal, 1))),
Ũ⃗ = Ũ⃗_init,
a = zeros(n_drives),
)

Expand Down

0 comments on commit 980a4fa

Please sign in to comment.