44# Use of this source code is governed by an MIT-style license that can be found
55# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
66
7+
78"""
89 struct ColoringResult
9- result::SMC.TreeSetColoringResult
10+ result::SMC.AbstractColoringResult
1011 local_indices::Vector{Int} # map from local to global indices
12+ full_colptr::Vector{Int} # colptr of the full symmetric matrix used for coloring
13+ lower_pos::Vector{Int} # positions of lower-triangular entries in the full nzval
14+ full_buffer::Vector{Float64} # pre-allocated buffer of size nnz(full symmetric matrix)
1115 end
1216
13- Wrapper around TreeSetColoringResult that also stores local_indices mapping.
17+ Wrapper around AbstractColoringResult that also stores auxiliary data needed
18+ for Hessian recovery from a full symmetric matrix decompression.
1419"""
1520struct ColoringResult{R<: SMC.AbstractColoringResult }
1621 result:: R
17- local_indices:: Vector{Int} # map from local to global indices
22+ local_indices:: Vector{Int} # map from local to global indices
23+ full_colptr:: Vector{Int} # colptr of full symmetric matrix used for coloring
24+ lower_pos:: Vector{Int} # positions of lower-triangular entries in full nzval
25+ full_buffer:: Vector{Float64} # scratch buffer of length nnz(full symmetric matrix)
1826end
1927
2028"""
2836`edgelist` contains the nonzeros in the Hessian, *including* nonzeros on the
2937diagonal.
3038
31- Returns `(I, J, result)` where `I` and `J` are the row and column indices
32- of the Hessian structure, and `result` is a `TreeSetColoringResult` from
33- SparseMatrixColorings .
39+ Returns `(colptr, I, J, result)` where `colptr`, ` I` and `J` define the lower
40+ triangular CSC sparsity structure of the Hessian (in global variable indices),
41+ and `result` is a `ColoringResult` wrapping an SMC coloring result .
3442"""
3543function _hessian_color_preprocess (
3644 edgelist,
@@ -39,13 +47,16 @@ function _hessian_color_preprocess(
3947 seen_idx = MOI. Nonlinear. Coloring. IndexedSet (0 ),
4048)
4149 resize! (seen_idx, num_total_var)
42- I, J = Int[], Int[]
50+ # Collect off-diagonal lower-triangular entries (local coords, filled later)
51+ I_off, J_off = Int[], Int[]
4352 for (ei, ej) in edgelist
4453 push! (seen_idx, ei)
4554 push! (seen_idx, ej)
46- # Store in lower triangular format: row >= col
47- push! (I, max (ei, ej))
48- push! (J, min (ei, ej))
55+ if ei != ej
56+ # Store in lower triangular format: row > col
57+ push! (I_off, max (ei, ej))
58+ push! (J_off, min (ei, ej))
59+ end
4960 end
5061 local_indices = sort! (collect (seen_idx))
5162 empty! (seen_idx)
@@ -54,36 +65,58 @@ function _hessian_color_preprocess(
5465 for k in eachindex (local_indices)
5566 global_to_local_idx[local_indices[k]] = k
5667 end
57- # only do the coloring on the local indices
58- for k in eachindex (I )
59- I [k] = global_to_local_idx[I [k]]
60- J [k] = global_to_local_idx[J [k]]
68+ # Map off-diagonal entries to local indices
69+ for k in eachindex (I_off )
70+ I_off [k] = global_to_local_idx[I_off [k]]
71+ J_off [k] = global_to_local_idx[J_off [k]]
6172 end
6273
6374 n = length (local_indices)
64- # Always include diagonal entries (needed for Hessian recovery)
75+
76+ # Build full symmetric matrix: both (i,j) and (j,i) for off-diagonal, plus diagonal
77+ I_full, J_full = Int[], Int[]
78+ for k in eachindex (I_off)
79+ push! (I_full, I_off[k]); push! (J_full, J_off[k]) # lower
80+ push! (I_full, J_off[k]); push! (J_full, I_off[k]) # upper (transpose)
81+ end
6582 for k in 1 : n
66- push! (I, k)
67- push! (J, k)
83+ push! (I_full, k); push! (J_full, k) # diagonal
6884 end
85+ mat_sym = SparseArrays. sparse (I_full, J_full, trues (length (I_full)), n, n, | )
6986
70- # Create lower triangular sparsity pattern (including diagonal)
71- mat = SparseArrays. sparse (I, J, trues (length (I)), n, n, | )
72- S = SMC. SparsityPatternCSC (mat)
73-
74- # Perform coloring using SMC
87+ # Perform coloring on full symmetric matrix
88+ S = SMC. SparsityPatternCSC (mat_sym)
7589 problem = SMC. ColoringProblem (; structure = :symmetric , partition = :column )
7690 tree_result = SMC. coloring (S, problem, algo)
7791
78- # Wrap result with local_indices
79- result = ColoringResult (tree_result, local_indices)
92+ # Find positions of lower-triangular entries within the full CSC nzval array.
93+ # findnz on a CSC matrix returns elements in CSC (column-major) order,
94+ # matching the nzval layout exactly.
95+ I_nz, J_nz, _ = SparseArrays. findnz (mat_sym)
96+ lower_pos = findall (k -> I_nz[k] >= J_nz[k], eachindex (I_nz))
97+
98+ # Lower-triangular CSC-ordered local indices
99+ I_low_csc = I_nz[lower_pos]
100+ J_low_csc = J_nz[lower_pos]
80101
81- # Get CSC-ordered indices from the sparse matrix, then map back to global
82- I_local, J_local, _ = SparseArrays. findnz (mat)
83- I_global = [local_indices[i] for i in I_local]
84- J_global = [local_indices[j] for j in J_local]
102+ # Map back to global indices for the returned hess_I / hess_J
103+ I_global = [local_indices[i] for i in I_low_csc]
104+ J_global = [local_indices[j] for j in J_low_csc]
85105
86- return copy (mat. colptr), I_global, J_global, result
106+ # Build lower-triangular sparse matrix to obtain its colptr
107+ mat_low = SparseArrays. sparse (I_low_csc, J_low_csc, trues (length (I_low_csc)), n, n, | )
108+
109+ full_buffer = Vector {Float64} (undef, SparseArrays. nnz (mat_sym))
110+
111+ result = ColoringResult (
112+ tree_result,
113+ local_indices,
114+ copy (mat_sym. colptr),
115+ lower_pos,
116+ full_buffer,
117+ )
118+
119+ return copy (mat_low. colptr), I_global, J_global, result
87120end
88121
89122"""
127160
128161Recover the Hessian values from the Hessian-matrix product H*R_seed.
129162R is the result of H*R_seed where R_seed is the seed matrix.
130- `stored_values` is a temporary vector.
163+ `stored_values` is a temporary vector (unused, kept for API compatibility) .
131164"""
132165function _recover_from_matmat! (
133166 colptr:: AbstractVector ,
@@ -136,6 +169,10 @@ function _recover_from_matmat!(
136169 result:: ColoringResult ,
137170 stored_values:: AbstractVector{T} ,
138171) where {T}
139- SMC. decompress_csc! (V, colptr, R, result. result, :L )
172+ # Decompress into the full symmetric buffer, then extract lower-triangular values.
173+ SMC. decompress_csc! (result. full_buffer, result. full_colptr, R, result. result, :F )
174+ for k in eachindex (V)
175+ V[k] = result. full_buffer[result. lower_pos[k]]
176+ end
140177 return
141178end
0 commit comments