Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix for moving average tracking #48

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/src/man/cls_overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,8 @@ solver = RLSSolver(
# Solve the system
sol = rsolve(solver, A, b)
```

## Block Methods for Linear Systems

Because of the way computers operate, it is often more efficient to generate updates
to a solution from blocks of data rather than from single vectors.
50 changes: 41 additions & 9 deletions src/linear_solver_logs/solve_log_ma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ mutable struct SEDistInfo
sigma2::Union{Float64, Nothing}
omega::Union{Float64, Nothing}
eta::Float64
scaling::Float64
end

"""
Expand Down Expand Up @@ -130,15 +131,15 @@ LSLogMA(;
eta = 1,
true_res = false
) = LSLogMA( collection_rate,
MAInfo(lambda1, lambda2, 1, false, 1, zeros(lambda2)),
MAInfo(lambda1, lambda2, lambda1, false, 1, zeros(lambda2)),
Float64[],
Float64[],
Int64[],
norm,
-1,
false,
true_res,
SEDistInfo(nothing, 0, 0, sigma2, omega, eta)
SEDistInfo(nothing, 0, 0, sigma2, omega, eta, 0)
)

#Function to update the moving average
Expand All @@ -149,7 +150,7 @@ function log_update!(
samp::Tuple,
iter::Int64,
A::AbstractArray,
b::AbstractVector
b::AbstractVector,
)
if iter == 0
# Check if it is a row or column method and record dimensions
Expand All @@ -171,7 +172,7 @@ function log_update!(

log.dist_info.sampler = typeof(sampler)
# If the constants for the sub-Exponential distribution are not defined then define them
if typeof(log.dist_info.sigma2) <: Nothing
if typeof(log.dist_info.sigma2) <: Nothing || log.dist_info.sigma2 == 0
get_SE_constants!(log, log.dist_info.sampler)
end

Expand All @@ -183,17 +184,18 @@ function log_update!(
if !log.true_res && iter > 0
# Compute the current residual to second power to align with theory
# Check if it is one dimensional or block sampling method
res::Float64 = eltype(samp[1]) <: Int64 || size(samp[1],2) != 1 ?
log.resid_norm(samp[3])^2 : log.resid_norm(dot(samp[1], x) - samp[2])^2
res::Float64 = log.dist_info.scaling * (eltype(samp[1]) <: Int64 || size(samp[1],2) != 1 ?
log.resid_norm(samp[3])^2 : log.resid_norm(dot(samp[1], x) - samp[2])^2)
else
res = log.resid_norm(A * x - b)^2
end

# Check if MA is in lambda1 or lambda2 regime
if ma_info.flag
update_ma!(log, res, ma_info.lambda2, iter)
else
#Check if we can switch between lambda1 and lambda2 regime
if res < ma_info.res_window[ma_info.idx]
if iter == 0 || res <= ma_info.res_window[ma_info.idx]
update_ma!(log, res, ma_info.lambda1, iter)
else
update_ma!(log, res, ma_info.lambda1, iter)
Expand Down Expand Up @@ -232,7 +234,7 @@ function update_ma!(log::LSLogMA, res::Union{AbstractVector, Real}, lambda_base:
ma_info.idx = ma_info.idx < ma_info.lambda2 ? ma_info.idx + 1 : 1
ma_info.res_window[ma_info.idx] = res
#Check if entire storage buffer can be used
if ma_info.lambda == lambda_base
if ma_info.lambda == ma_info.lambda2
# Compute the moving average
for i in 1:ma_info.lambda2
accum += ma_info.res_window[i]
Expand All @@ -244,6 +246,32 @@ function update_ma!(log::LSLogMA, res::Union{AbstractVector, Real}, lambda_base:
push!(log.resid_hist, accum / ma_info.lambda)
push!(log.iota_hist, accum2 / ma_info.lambda)
end

elseif ma_info.lambda == ma_info.lambda1 && !ma_info.flag
diff = ma_info.idx - ma_info.lambda
# Determine start point for first loop
startp1 = diff < 0 ? 1 : (diff + 1)

# Determine start and endpoints for second loop
startp2 = diff < 0 ? ma_info.lambda2 + diff + 1 : 2
endp2 = diff < 0 ? ma_info.lambda2 : 1
# Compute the moving average; a two-loop setup is required when lambda < lambda2
for i in startp1:ma_info.idx
accum += ma_info.res_window[i]
accum2 += ma_info.res_window[i]^2
end

for i in startp2:endp2
accum += ma_info.res_window[i]
accum2 += ma_info.res_window[i]^2
end

#Update the log variable with the information for this update
if mod(iter, log.collection_rate) == 0 || iter == 0
push!(log.lambda_hist, ma_info.lambda)
push!(log.resid_hist, accum / ma_info.lambda)
push!(log.iota_hist, accum2 / ma_info.lambda)
end

else
# Get the difference between the start and current lambda
Expand Down Expand Up @@ -332,7 +360,8 @@ This function is not exported and thus the user does not have direct access to i
- `sampler::Type{LinSysSampler}`, the type of sampler being used.

# Outputs
Performs an inplace update of the sub-Exponential constants for the log.
Performs an in-place update of the sub-Exponential constants for the log. Additionally, updates the scaling
constant so that the expectation of the block norms is equal to the true norm.
"""
function get_SE_constants!(log::LSLogMA, sampler::Type{T}) where T<:LinSysSampler
return nothing
Expand All @@ -347,6 +376,7 @@ for type in (LinSysVecRowDetermCyclic,LinSysVecRowHopRandCyclic,
@eval begin
function get_SE_constants!(log::LSLogMA, sampler::Type{$type})
log.dist_info.sigma2 = log.dist_info.dimension^2 / (4 * log.dist_info.block_dimension^2 * log.dist_info.eta)
log.dist_info.scaling = log.dist_info.dimension / log.dist_info.block_dimension
end

end
Expand All @@ -359,6 +389,7 @@ for type in (LinSysVecColOneRandCyclic, LinSysVecColDetermCyclic)
@eval begin
function get_SE_constants!(log::LSLogMA, sampler::Type{$type})
log.dist_info.sigma2 = log.dist_info.dimension^2 / (4 * log.dist_info.block_dimension^2 * log.dist_info.eta)
log.dist_info.scaling = log.dist_info.dimension / log.dist_info.block_dimension
end

end
Expand All @@ -371,6 +402,7 @@ for type in (LinSysVecRowGaussSampler, LinSysVecRowSparseGaussSampler)
function get_SE_constants!(log::LSLogMA, sampler::Type{$type})
log.dist_info.sigma2 = log.dist_info.block_dimension / (0.2345 * log.dist_info.eta)
log.dist_info.omega = .1127
log.dist_info.scaling = 1.
end

end
Expand Down
79 changes: 56 additions & 23 deletions test/linear_solver_logs/proc_solve_log_ma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,54 +23,87 @@ Random.seed!(1010)
z = rand(2)

sampler = LinSysVecRowOneRandCyclic()
log = LSLogMA()
logger = LSLogMA()


RLinearAlgebra.log_update!(log, sampler, z, (A[1,:],b[1]), 0, A, b)
RLinearAlgebra.log_update!(logger, sampler, z, (A[1,:],b[1]), 0, A, b)

@test length(log.resid_hist) == 1
@test log.resid_hist[1] == norm(A * z - b)^2
@test norm(log.iota_hist[1] - norm(A * z - b)^4) < 1e2 * eps()
@test log.iterations == 0
@test log.converged == false
@test length(logger.resid_hist) == 1
@test logger.resid_hist[1] == norm(A * z - b)^2
@test norm(logger.iota_hist[1] - norm(A * z - b)^4) < 1e2 * eps()
@test logger.iterations == 0
@test logger.converged == false
end

# Verify moving average
# Verify late moving average
let
A = rand(2,2)
x = rand(2)
b = A * x
z = rand(2)

sampler = LinSysVecRowOneRandCyclic()
log = LSLogMA(lambda2 = 2)
logger = LSLogMA(lambda2 = 2)
samp = (A[1,:], b[1])

RLinearAlgebra.log_update!(log, sampler, z, (A[1,:],b[1]), 0, A, b)
RLinearAlgebra.log_update!(logger, sampler, z, (A[1,:],b[1]), 0, A, b)
# Test moving average of log
for i = 1:10
samp = (A[1,:], b[1])
RLinearAlgebra.log_update!(log, sampler, x + (i+1)*(z-x), samp, i, A, b)
RLinearAlgebra.log_update!(logger, sampler, x + (i+1)*(z-x), samp, i, A, b)
end
#compute sampled residuals
obs_res = [abs(dot(A[1,:],x + (i+1)*(z-x)) - b[1])^2 for i = 0:10]
obs_res2 = [abs(dot(A[1,:],x + (i+1)*(z-x)) - b[1])^4 for i = 0:10]
@test length(log.resid_hist) == 11
@test norm(log.resid_hist[2:11] - vcat(obs_res[2],
[(obs_res[i] + obs_res[i-1])/2 for i = 3:11])) < 1e2 * eps()
@test norm(log.iota_hist[2:11] - vcat(obs_res2[2],
[(obs_res2[i] + obs_res2[i-1])/2 for i = 3:11])) < 1e2 * eps()
@test log.iterations == 10
@test log.converged == false
obs_res = 2 .* [abs(dot(A[1,:],x + (i+1)*(z-x)) - b[1])^2 for i = 0:10]
obs_res2 = 4 .* [abs(dot(A[1,:],x + (i+1)*(z-x)) - b[1])^4 for i = 0:10]
@test length(logger.resid_hist) == 11
@test norm(logger.resid_hist[3:11] - vcat(obs_res[3],
[(obs_res[i] + obs_res[i-1])/2 for i = 4:11])) < 1e2 * eps()
@test norm(logger.iota_hist[3:11] - vcat(obs_res2[3],
[(obs_res2[i] + obs_res2[i-1])/2 for i = 4:11])) < 1e2 * eps()
@test logger.iterations == 10
@test logger.converged == false

#Test uncertainty set
Uncertainty_set = get_uncertainty(log)
Uncertainty_set = get_uncertainty(logger)
@test length(Uncertainty_set[1]) == 11
# Undoing the steps of the interval calculation should give a ratio of 1
@test norm((Uncertainty_set[2] - log.resid_hist) ./ sqrt.(2 * Base.log(2/.05) * log.iota_hist *
log.dist_info.sigma2 .* (1 .+ Base.log.(log.lambda_hist)) ./ log.lambda_hist) .- 1) < 1e2 * eps()
@test norm((Uncertainty_set[2] - logger.resid_hist) ./ sqrt.(2 * log(2/.05) * logger.iota_hist *
logger.dist_info.sigma2 .* (1 .+ log.(logger.lambda_hist)) ./ logger.lambda_hist) .- 1) < 1e2 * eps()
end
# Verify the early moving average, i.e. the width-lambda1 window used before any
# switch to the lambda2 regime.
let
    A = rand(2,2)
    x = rand(2)
    b = A * x
    z = rand(2)

    sampler = LinSysVecRowOneRandCyclic()
    logger = LSLogMA(lambda1 = 2, lambda2 = 10)

    # Fresh copy of the first row and its right-hand-side entry for every update
    first_row_sample() = (A[1,:], b[1])
    # Iterate whose error (z - x) is shrunk by 0.3^(k+1), so the sampled residual
    # only decreases and the switch to the lambda2 regime is not triggered
    shrunk_iterate(k) = x + .3^(k+1)*(z-x)

    # Iteration 0 is logged at the starting point z
    RLinearAlgebra.log_update!(logger, sampler, z, first_row_sample(), 0, A, b)
    for k in 1:10
        RLinearAlgebra.log_update!(
            logger, sampler, shrunk_iterate(k), first_row_sample(), k, A, b
        )
    end

    # Sampled row residuals for k = 0, ..., 10.
    # NOTE(review): the factors 2 and 4 are presumably the logger's
    # dimension / block_dimension scaling and its square -- confirm.
    row_res = [dot(A[1,:], shrunk_iterate(k)) - b[1] for k in 0:10]
    obs_res = 2 .* [abs(r)^2 for r in row_res]
    obs_res2 = 4 .* [abs(r)^4 for r in row_res]

    # Expected width-2 moving averages for history entries 3 through 11
    expected_resid = [(obs_res[j] + obs_res[j-1])/2 for j in 3:11]
    expected_iota = [(obs_res2[j] + obs_res2[j-1])/2 for j in 3:11]

    @test length(logger.resid_hist) == 11
    @test norm(logger.resid_hist[3:11] - expected_resid) < 1e2 * eps()
    @test norm(logger.iota_hist[3:11] - expected_iota) < 1e2 * eps()
    @test logger.iterations == 10
    @test logger.converged == false

    # Test uncertainty set
    bounds = get_uncertainty(logger)
    @test length(bounds[1]) == 11
    # Dividing the offset of the second returned bound from the moving average by the
    # recomputed interval half-width should give a ratio of 1
    half_width = sqrt.(2 * log(2/.05) * logger.iota_hist *
        logger.dist_info.sigma2 .* (1 .+ log.(logger.lambda_hist)) ./ logger.lambda_hist)
    @test norm((bounds[2] - logger.resid_hist) ./ half_width .- 1) < 1e2 * eps()
end
end

end # End Module
Loading