Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,17 @@ end
function ConstantColoringAlgorithm{:column}(
matrix_template::AbstractMatrix, color::Vector{Int}
)
S = convert(SparseMatrixCSC, matrix_template)
result = ColumnColoringResult(S, color)
bg = BipartiteGraph(matrix_template)
result = ColumnColoringResult(matrix_template, bg, color)
M, R = typeof(matrix_template), typeof(result)
return ConstantColoringAlgorithm{:column,M,R}(matrix_template, color, result)
end

function ConstantColoringAlgorithm{:row}(
matrix_template::AbstractMatrix, color::Vector{Int}
)
S = convert(SparseMatrixCSC, matrix_template)
result = RowColoringResult(S, color)
bg = BipartiteGraph(matrix_template)
result = RowColoringResult(matrix_template, bg, color)
M, R = typeof(matrix_template), typeof(result)
return ConstantColoringAlgorithm{:row,M,R}(matrix_template, color, result)
end
Expand Down
116 changes: 55 additions & 61 deletions src/decompression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,8 @@ true
- [`ColoringProblem`](@ref)
- [`AbstractColoringResult`](@ref)
"""
function decompress(B::AbstractMatrix{R}, result::AbstractColoringResult) where {R<:Real}
@compat (; S) = result
A = respectful_similar(S, R)
function decompress(B::AbstractMatrix, result::AbstractColoringResult)
A = respectful_similar(result.A, eltype(B))
Comment thread
gdalle marked this conversation as resolved.
return decompress!(A, B, result)
end

Expand Down Expand Up @@ -264,12 +263,11 @@ end

## ColumnColoringResult

function decompress!(
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::ColumnColoringResult
) where {R<:Real}
@compat (; S, color) = result
function decompress!(A::AbstractMatrix, B::AbstractMatrix, result::ColumnColoringResult)
@compat (; color) = result
S = result.bg.S2
check_same_pattern(A, S)
A .= zero(R)
fill!(A, zero(eltype(A)))
Comment thread
gdalle marked this conversation as resolved.
rvS = rowvals(S)
for j in axes(S, 2)
cj = color[j]
Expand All @@ -282,9 +280,10 @@ function decompress!(
end

function decompress_single_color!(
A::AbstractMatrix{R}, b::AbstractVector{R}, c::Integer, result::ColumnColoringResult
) where {R<:Real}
@compat (; S, group) = result
A::AbstractMatrix, b::AbstractVector, c::Integer, result::ColumnColoringResult
)
@compat (; group) = result
S = result.bg.S2
check_same_pattern(A, S)
rvS = rowvals(S)
for j in group[c]
Expand All @@ -296,10 +295,9 @@ function decompress_single_color!(
return A
end

function decompress!(
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::ColumnColoringResult
) where {R<:Real}
@compat (; S, compressed_indices) = result
function decompress!(A::SparseMatrixCSC, B::AbstractMatrix, result::ColumnColoringResult)
@compat (; compressed_indices) = result
S = result.bg.S2
check_same_pattern(A, S)
nzA = nonzeros(A)
for k in eachindex(nzA, compressed_indices)
Expand All @@ -309,9 +307,10 @@ function decompress!(
end

function decompress_single_color!(
A::SparseMatrixCSC{R}, b::AbstractVector{R}, c::Integer, result::ColumnColoringResult
) where {R<:Real}
@compat (; S, group) = result
A::SparseMatrixCSC, b::AbstractVector, c::Integer, result::ColumnColoringResult
)
@compat (; group) = result
S = result.bg.S2
check_same_pattern(A, S)
rvS = rowvals(S)
nzA = nonzeros(A)
Expand All @@ -326,12 +325,11 @@ end

## RowColoringResult

function decompress!(
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::RowColoringResult
) where {R<:Real}
@compat (; S, color) = result
function decompress!(A::AbstractMatrix, B::AbstractMatrix, result::RowColoringResult)
@compat (; color) = result
S = result.bg.S2
check_same_pattern(A, S)
A .= zero(R)
fill!(A, zero(eltype(A)))
rvS = rowvals(S)
for j in axes(S, 2)
for k in nzrange(S, j)
Expand All @@ -344,9 +342,10 @@ function decompress!(
end

function decompress_single_color!(
A::AbstractMatrix{R}, b::AbstractVector{R}, c::Integer, result::RowColoringResult
) where {R<:Real}
@compat (; S, Sᵀ, group) = result
A::AbstractMatrix, b::AbstractVector, c::Integer, result::RowColoringResult
)
@compat (; group) = result
S, Sᵀ = result.bg.S2, result.bg.S1
check_same_pattern(A, S)
rvSᵀ = rowvals(Sᵀ)
for i in group[c]
Expand All @@ -358,10 +357,9 @@ function decompress_single_color!(
return A
end

function decompress!(
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::RowColoringResult
) where {R<:Real}
@compat (; S, compressed_indices) = result
function decompress!(A::SparseMatrixCSC, B::AbstractMatrix, result::RowColoringResult)
@compat (; compressed_indices) = result
S = result.bg.S2
check_same_pattern(A, S)
nzA = nonzeros(A)
for k in eachindex(nzA, compressed_indices)
Expand All @@ -373,15 +371,13 @@ end
## StarSetColoringResult

function decompress!(
A::AbstractMatrix{R},
B::AbstractMatrix{R},
result::StarSetColoringResult,
uplo::Symbol=:F,
) where {R<:Real}
@compat (; S, color, star_set) = result
A::AbstractMatrix, B::AbstractMatrix, result::StarSetColoringResult, uplo::Symbol=:F
)
@compat (; color, star_set) = result
@compat (; star, hub, spokes) = star_set
S = result.ag.S
uplo == :F && check_same_pattern(A, S)
A .= zero(R)
fill!(A, zero(eltype(A)))
for i in axes(A, 1)
if !iszero(S[i, i])
A[i, i] = B[i, color[i]]
Expand All @@ -403,14 +399,15 @@ function decompress!(
end

function decompress_single_color!(
A::AbstractMatrix{R},
b::AbstractVector{R},
A::AbstractMatrix,
b::AbstractVector,
c::Integer,
result::StarSetColoringResult,
uplo::Symbol=:F,
) where {R<:Real}
@compat (; S, color, group, star_set) = result
)
@compat (; color, group, star_set) = result
@compat (; hub, spokes) = star_set
S = result.ag.S
uplo == :F && check_same_pattern(A, S)
for i in axes(A, 1)
if !iszero(S[i, i]) && color[i] == c
Expand All @@ -434,12 +431,10 @@ function decompress_single_color!(
end

function decompress!(
A::SparseMatrixCSC{R},
B::AbstractMatrix{R},
result::StarSetColoringResult,
uplo::Symbol=:F,
) where {R<:Real}
@compat (; S, compressed_indices) = result
A::SparseMatrixCSC, B::AbstractMatrix, result::StarSetColoringResult, uplo::Symbol=:F
)
@compat (; compressed_indices) = result
S = result.ag.S
nzA = nonzeros(A)
if uplo == :F
check_same_pattern(A, S)
Expand Down Expand Up @@ -468,14 +463,13 @@ end
# TODO: add method for A::SparseMatrixCSC

function decompress!(
A::AbstractMatrix{R},
B::AbstractMatrix{R},
result::TreeSetColoringResult,
uplo::Symbol=:F,
) where {R<:Real}
@compat (; S, color, vertices_by_tree, reverse_bfs_orders, buffer) = result
A::AbstractMatrix, B::AbstractMatrix, result::TreeSetColoringResult, uplo::Symbol=:F
)
@compat (; color, vertices_by_tree, reverse_bfs_orders, buffer) = result
S = result.ag.S
uplo == :F && check_same_pattern(A, S)
A .= zero(R)
R = eltype(A)
fill!(A, zero(R))

if eltype(buffer) == R
buffer_right_type = buffer
Expand Down Expand Up @@ -513,19 +507,19 @@ end
## MatrixInverseColoringResult

function decompress!(
A::AbstractMatrix{R},
B::AbstractMatrix{R},
A::AbstractMatrix,
B::AbstractMatrix,
result::LinearSystemColoringResult,
uplo::Symbol=:F,
) where {R<:Real}
@compat (;
S, color, strict_upper_nonzero_inds, T_factorization, strict_upper_nonzeros_A
) = result
)
@compat (; color, strict_upper_nonzero_inds, T_factorization, strict_upper_nonzeros_A) =
result
S = result.ag.S
uplo == :F && check_same_pattern(A, S)

# TODO: for some reason I cannot use ldiv! with a sparse QR
strict_upper_nonzeros_A = T_factorization \ vec(B)
A .= zero(R)
fill!(A, zero(eltype(A)))
for i in axes(A, 1)
if !iszero(S[i, i])
A[i, i] = B[i, color[i]]
Expand Down
27 changes: 27 additions & 0 deletions src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ end
SparsityPatternCSC(A::SparseMatrixCSC) = SparsityPatternCSC(A.m, A.n, A.colptr, A.rowval)

Base.size(S::SparsityPatternCSC) = (S.m, S.n)

function Base.size(S::SparsityPatternCSC, d::Integer)
if d == 1
return S.m
elseif d == 2
return S.n
else
return 1
Comment thread
gdalle marked this conversation as resolved.
Outdated
end
end

Base.axes(S::SparsityPatternCSC, d::Integer) = Base.OneTo(size(S, d))

SparseArrays.nnz(S::SparsityPatternCSC) = length(S.rowval)
SparseArrays.rowvals(S::SparsityPatternCSC) = S.rowval
SparseArrays.nzrange(S::SparsityPatternCSC, j::Integer) = S.colptr[j]:(S.colptr[j + 1] - 1)
Expand Down Expand Up @@ -81,6 +94,15 @@ function Base.transpose(S::SparsityPatternCSC{T}) where {T}
return SparsityPatternCSC{T}(n, m, B_colptr, B_rowval)
end

# copied from SparseArrays.jl
function Base.getindex(S::SparsityPatternCSC, i0::Integer, i1::Integer)
Comment thread
gdalle marked this conversation as resolved.
r1 = Int(S.colptr[i1])
r2 = Int(S.colptr[i1 + 1] - 1)
(r1 > r2) && return false
r1 = searchsortedfirst(rowvals(S), i0, r1, r2, Base.Order.Forward)
return ((r1 > r2) || (rowvals(S)[r1] != i0)) ? false : true
end

## Adjacency graph

"""
Expand Down Expand Up @@ -109,6 +131,7 @@ struct AdjacencyGraph{T}
S::SparsityPatternCSC{T}
end

AdjacencyGraph(A::AbstractMatrix) = AdjacencyGraph(SparseMatrixCSC(A))
AdjacencyGraph(A::SparseMatrixCSC) = AdjacencyGraph(SparsityPatternCSC(A))

pattern(g::AdjacencyGraph) = g.S
Expand Down Expand Up @@ -183,6 +206,10 @@ struct BipartiteGraph{T<:Integer}
S2::SparsityPatternCSC{T}
end

function BipartiteGraph(A::AbstractMatrix; symmetric_pattern::Bool=false)
return BipartiteGraph(SparseMatrixCSC(A); symmetric_pattern)
end

function BipartiteGraph(A::SparseMatrixCSC; symmetric_pattern::Bool=false)
S2 = SparsityPatternCSC(A) # columns to rows
if symmetric_pattern
Expand Down
29 changes: 11 additions & 18 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,11 @@ function coloring(
decompression_eltype::Type=Float64,
symmetric_pattern::Bool=false,
)
S = convert(SparseMatrixCSC, A)
bg = BipartiteGraph(
S; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
A; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
)
color = partial_distance2_coloring(bg, Val(2), algo.order)
return ColumnColoringResult(S, color)
return ColumnColoringResult(A, bg, color)
end

function coloring(
Expand All @@ -195,12 +194,11 @@ function coloring(
decompression_eltype::Type=Float64,
symmetric_pattern::Bool=false,
)
S = convert(SparseMatrixCSC, A)
bg = BipartiteGraph(
S; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
A; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
)
color = partial_distance2_coloring(bg, Val(1), algo.order)
return RowColoringResult(S, color)
return RowColoringResult(A, bg, color)
end

function coloring(
Expand All @@ -209,10 +207,9 @@ function coloring(
algo::GreedyColoringAlgorithm{:direct};
decompression_eltype::Type=Float64,
)
S = convert(SparseMatrixCSC, A)
ag = AdjacencyGraph(S)
ag = AdjacencyGraph(A)
color, star_set = star_coloring(ag, algo.order)
return StarSetColoringResult(S, color, star_set)
return StarSetColoringResult(A, ag, color, star_set)
end

function coloring(
Expand All @@ -221,31 +218,27 @@ function coloring(
algo::GreedyColoringAlgorithm{:substitution};
decompression_eltype::Type=Float64,
)
S = convert(SparseMatrixCSC, A)
ag = AdjacencyGraph(S)
ag = AdjacencyGraph(A)
color, tree_set = acyclic_coloring(ag, algo.order)
return TreeSetColoringResult(S, color, tree_set, decompression_eltype)
return TreeSetColoringResult(A, ag, color, tree_set, decompression_eltype)
end

## ADTypes interface

function ADTypes.column_coloring(A::AbstractMatrix, algo::GreedyColoringAlgorithm)
S = convert(SparseMatrixCSC, A)
bg = BipartiteGraph(S; symmetric_pattern=A isa Union{Symmetric,Hermitian})
bg = BipartiteGraph(A; symmetric_pattern=A isa Union{Symmetric,Hermitian})
color = partial_distance2_coloring(bg, Val(2), algo.order)
return color
end

function ADTypes.row_coloring(A::AbstractMatrix, algo::GreedyColoringAlgorithm)
S = convert(SparseMatrixCSC, A)
bg = BipartiteGraph(S; symmetric_pattern=A isa Union{Symmetric,Hermitian})
bg = BipartiteGraph(A; symmetric_pattern=A isa Union{Symmetric,Hermitian})
color = partial_distance2_coloring(bg, Val(1), algo.order)
return color
end

function ADTypes.symmetric_coloring(A::AbstractMatrix, algo::GreedyColoringAlgorithm)
S = convert(SparseMatrixCSC, A)
ag = AdjacencyGraph(S)
ag = AdjacencyGraph(A)
color, star_set = star_coloring(ag, algo.order)
return color
end
Loading