Skip to content

Commit 9ed4d93

Browse files
committed
Fix Hessian coloring: use lower triangular CSC format
- Build lower triangular sparsity pattern (row >= col) instead of full symmetric matrix, so decompress_csc! with :L works correctly - Always add diagonal entries explicitly (old Coloring module did this implicitly; SMC requires them to be in the sparsity pattern) - Get CSC structure directly from the sparse matrix instead of going through compress/decompress - Switch decompress_csc! from :U to :L to match lower triangular storage - Update test expectations to CSC lower triangular order: J=[1,1,2] (was [1,2,1]), V=[3.4,2.1,1.3] (was [3.4,1.3,2.1]) https://claude.ai/code/session_01WBu9hZukriWDSSybN9gfBq
1 parent 160886e commit 9ed4d93

2 files changed

Lines changed: 21 additions & 21 deletions

File tree

src/coloring.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,12 @@ function _hessian_color_preprocess(
4040
)
4141
resize!(seen_idx, num_total_var)
4242
I, J = Int[], Int[]
43-
for (i, j) in edgelist
44-
push!(seen_idx, i)
45-
push!(seen_idx, j)
46-
push!(I, i)
47-
push!(J, j)
48-
if i != j
49-
push!(I, j)
50-
push!(J, i)
51-
end
43+
for (ei, ej) in edgelist
44+
push!(seen_idx, ei)
45+
push!(seen_idx, ej)
46+
# Store in lower triangular format: row >= col
47+
push!(I, max(ei, ej))
48+
push!(J, min(ei, ej))
5249
end
5350
local_indices = sort!(collect(seen_idx))
5451
empty!(seen_idx)
@@ -63,11 +60,16 @@ function _hessian_color_preprocess(
6360
J[k] = global_to_local_idx[J[k]]
6461
end
6562

66-
# Create sparsity pattern matrix
6763
n = length(local_indices)
68-
S = SMC.SparsityPatternCSC(
69-
SparseArrays.sparse(I, J, trues(length(I)), n, n, &),
70-
)
64+
# Always include diagonal entries (needed for Hessian recovery)
65+
for k in 1:n
66+
push!(I, k)
67+
push!(J, k)
68+
end
69+
70+
# Create lower triangular sparsity pattern (including diagonal)
71+
mat = SparseArrays.sparse(I, J, trues(length(I)), n, n, |)
72+
S = SMC.SparsityPatternCSC(mat)
7173

7274
# Perform coloring using SMC
7375
problem = SMC.ColoringProblem(; structure = :symmetric, partition = :column)
@@ -76,12 +78,10 @@ function _hessian_color_preprocess(
7678
# Wrap result with local_indices
7779
result = ColoringResult(tree_result, local_indices)
7880

79-
# SparseMatrixColorings assumes that `I` and `J` are CSC-ordered
80-
B = SMC.compress(S, tree_result)
81-
C = SMC.decompress(B, tree_result)
82-
I_sorted, J_sorted = SparseArrays.findnz(C)
81+
# Get CSC-ordered indices directly from the sparse matrix
82+
I_sorted, J_sorted, _ = SparseArrays.findnz(mat)
8383

84-
return C.colptr, I_sorted, J_sorted, result
84+
return copy(mat.colptr), I_sorted, J_sorted, result
8585
end
8686

8787
"""
@@ -134,6 +134,6 @@ function _recover_from_matmat!(
134134
result::ColoringResult,
135135
stored_values::AbstractVector{T},
136136
) where {T}
137-
SMC.decompress_csc!(V, colptr, R, result.result, :U)
137+
SMC.decompress_csc!(V, colptr, R, result.result, :L)
138138
return
139139
end

test/ReverseAD.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,13 +426,13 @@ function test_coloring_end_to_end_hessian_coloring_and_recovery()
426426
R = ArrayDiff._seed_matrix(rinfo)
427427
ArrayDiff._prepare_seed_matrix!(R, rinfo)
428428
@test I == [1, 2, 2]
429-
@test J == [1, 2, 1]
429+
@test J == [1, 1, 2]
430430
@test R == [1.0 0.0; 0.0 1.0]
431431
hess = [3.4 2.1; 2.1 1.3]
432432
matmat = hess * R
433433
V = zeros(3)
434434
ArrayDiff._recover_from_matmat!(colptr, V, matmat, rinfo, zeros(3))
435-
@test V == [3.4, 1.3, 2.1]
435+
@test V == [3.4, 2.1, 1.3]
436436
return
437437
end
438438

0 commit comments

Comments
 (0)