Skip to content

Commit

Permalink
Ensure block-bandwidths BlockSkylineMatrix products are not larger th…
Browse files Browse the repository at this point in the history
…an necessary (#21)
  • Loading branch information
jagot authored and dlfivefifty committed Nov 26, 2018
1 parent 3d06e72 commit 11c28ef
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
26 changes: 24 additions & 2 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,28 @@ end
*(A::Matrix, B::BandedBlockBandedMatrix) = materialize(Mul(A,B))


function add_bandwidths(A::AbstractBlockBandedMatrix,B::AbstractBlockBandedMatrix)
Al,Au = colblockbandwidths(A)
Bl,Bu = colblockbandwidths(B)

l = copy(Al)
u = copy(Au)

for (v,Bv) in [(l,Bl),(u,Bu)]
n = length(v)
for i = 1:n
sel = max(i-Au[i],1):min(i+Al[i],n)
isempty(sel) && continue
v[i] += maximum(Bv[sel])
end
end

l,u
end

add_bandwidths(A::BlockBandedMatrix,B::BlockBandedMatrix) =
colblockbandwidths(A) .+ colblockbandwidths(B)

function similar(M::MatMulMat{<:AbstractBlockBandedLayout,<:AbstractBlockBandedLayout}, ::Type{T}) where T
A,B = M.factors
Arows, Acols = A.block_sizes.block_sizes.cumul_sizes
Expand All @@ -115,8 +137,8 @@ function similar(M::MatMulMat{<:AbstractBlockBandedLayout,<:AbstractBlockBandedL
end
n,m = size(A,1), size(B,2)

l, u = blockbandwidths(A) .+ blockbandwidths(B)
BlockBandedMatrix{T}(undef, BlockBandedSizes(BlockSizes((Arows,Bcols)), l, u))
l,u = add_bandwidths(A,B)
BlockSkylineMatrix{T}(undef, BlockSkylineSizes(BlockSizes((Arows,Bcols)), l, u))
end

function similar(M::MatMulMat{BandedBlockBandedColumnMajor,BandedBlockBandedColumnMajor}, ::Type{T}) where T
Expand Down
43 changes: 42 additions & 1 deletion test/test_blockskyline.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

using LazyArrays, BlockBandedMatrices, LinearAlgebra, Random, Test
import BlockBandedMatrices: colblockbandwidths

Random.seed!(0)

Expand All @@ -26,4 +26,45 @@ Random.seed!(0)
@view(V[:,2]) .= Mul(A, @view(V[:,1]))
@test V[:,2] reference
end

@testset "BlockSkylineMatrix multiplication" begin
rows = [3, 1, 2, 1, 2, 1, 2, 1, 2, 1, 3]
l,u = [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1], [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1]

M = BlockSkylineMatrix{Float64}(undef, (rows,rows), (l,u))
M.data .= 1

d = Diagonal(1.0:size(M,2))
D = BandedBlockBandedMatrix(d, (rows,rows), (0,0), (0,0))

MD = M*D
@test MD isa BlockSkylineMatrix
@test MD == Matrix(M)*d
@test colblockbandwidths(MD) == (l,u)

MM = M*M
@test MM isa BlockSkylineMatrix
@test MM == Matrix(M)^2
# Ensure correct (minimal) bandedness of product
MMl,MMu = colblockbandwidths(MM)
@test MMl[1:7] == [3,4,3,4,3,4,3]
@test all(MMl[8:10] .≥ [3,2,1])
@test all(MMu[2:4] .≥ [1,2,3])
@test MMu[5:11] == [3,4,3,4,3,4,3]

N = BlockBandedMatrix{Float64}(undef, (rows,rows), (1,1))
N.data .= 1
NN = N*N
# We don't want a BlockBandedMatrix^2 to become a general
# BlockSkylineMatrix
@test NN isa BlockBandedMatrix
@test NN == Matrix(N)^2

rows = [9, 4, 1, 10, 6]
O = BlockSkylineMatrix{Int64}(undef, (rows,rows), ([-2, 2, 0, 2, -1],[-1, 2, 1, 0, -1]))
O.data .= 1
OO = O*O
@test OO isa BlockSkylineMatrix
@test OO == Matrix(O)^2
end
end

0 comments on commit 11c28ef

Please sign in to comment.