Skip to content

Commit d5df492

Browse files
authored
refined hessian sparsity detection (#29)
* refined hessian sparsity detection * Apply suggestions from code review
1 parent 68dbc68 commit d5df492

4 files changed

Lines changed: 161 additions & 124 deletions

File tree

src/graph_tools.jl

Lines changed: 153 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,75 @@ function _compute_gradient_sparsity!(
175175
return
176176
end
177177

178+
"""
179+
_get_nonlinear_child_interactions(
180+
node::Nonlinear.Node,
181+
num_children::Int,
182+
)
183+
184+
Get the list of nonlinear child interaction pairs for a node.
185+
Returns a list of tuples `(i, j)`, where `i` and `j` are child indices (1-indexed)
186+
that have nonlinear interactions.
187+
188+
For example, for `*` with 2 children, the result is `[(1, 2)]` because children 1
189+
and 2 interact nonlinearly, but neither child interacts nonlinearly with itself.
190+
191+
For functions like `+` or `-`, the result is `[]` since there are no nonlinear
192+
interactions between children.
193+
"""
194+
function _get_nonlinear_child_interactions(
195+
node::Nonlinear.Node,
196+
num_children::Int,
197+
)::Vector{Tuple{Int,Int}}
198+
if node.type == Nonlinear.NODE_CALL_UNIVARIATE
199+
@assert num_children == 1
200+
op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, node.index, nothing)
201+
# Univariate operators :+ and :- don't create interactions
202+
if op in (:+, :-)
203+
return Tuple{Int,Int}[]
204+
else
205+
return [(1, 1)]
206+
end
207+
elseif node.type == Nonlinear.NODE_CALL_MULTIVARIATE
208+
op = get(Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS, node.index, nothing)
209+
if op in (:+, :-, :ifelse, :min, :max)
210+
# No nonlinear interactions between children
211+
return Tuple{Int,Int}[]
212+
elseif op == :*
213+
# All pairs of distinct children interact nonlinearly
214+
result = Tuple{Int,Int}[]
215+
for i in 1:num_children
216+
for j in 1:(i-1)
217+
push!(result, (j, i))
218+
end
219+
end
220+
return result
221+
elseif op == :/
222+
@assert num_children == 2
223+
# The numerator doesn't have a nonlinear interaction with itself.
224+
return [(1, 2), (2, 2)]
225+
else
226+
# Conservative: assume all pairs interact
227+
result = Tuple{Int,Int}[]
228+
for i in 1:num_children
229+
for j in 1:i
230+
push!(result, (j, i))
231+
end
232+
end
233+
return result
234+
end
235+
else
236+
# Logic and comparison nodes don't generate hessian terms.
237+
# Subexpression nodes are special cased.
238+
return Tuple{Int,Int}[]
239+
end
240+
end
241+
178242
"""
179243
_compute_hessian_sparsity(
180244
nodes::Vector{Nonlinear.Node},
181245
adj,
182246
input_linearity::Vector{Linearity},
183-
indexedset::Coloring.IndexedSet,
184247
subexpression_edgelist::Vector{Set{Tuple{Int,Int}}},
185248
subexpression_variables::Vector{Vector{Int}},
186249
)
@@ -193,142 +256,118 @@ Compute the sparsity pattern the Hessian of an expression.
193256
* `subexpression_variables` is the list of all variables which appear in a
194257
subexpression (including recursively).
195258
196-
Idea: consider the (non)linearity of a node *with respect to the output*. The
197-
children of any node which is nonlinear with respect to the output should have
198-
nonlinear interactions, hence nonzeros in the hessian. This is not true in
199-
general, but holds for everything we consider.
200-
201-
A counter example is `f(x, y, z) = x + y * z`, but we don't have any functions
202-
like that. By "nonlinear with respect to the output", we mean that the output
203-
depends nonlinearly on the value of the node, regardless of how the node itself
204-
depends on the input.
259+
Returns a `Set{Tuple{Int,Int}}` containing the index pairs of the structurally nonzero entries of the Hessian.
205260
"""
206261
function _compute_hessian_sparsity(
207262
nodes::Vector{Nonlinear.Node},
208263
adj,
209264
input_linearity::Vector{Linearity},
210-
indexedset::Coloring.IndexedSet,
211265
subexpression_edgelist::Vector{Set{Tuple{Int,Int}}},
212266
subexpression_variables::Vector{Vector{Int}},
213267
)
214-
# So start at the root of the tree and classify the linearity wrt the output.
215-
# For each nonlinear node, do a mini DFS and collect the list of children.
216-
# Add a nonlinear interaction between all children of a nonlinear node.
217268
edge_list = Set{Tuple{Int,Int}}()
218-
nonlinear_wrt_output = fill(false, length(nodes))
219269
children_arr = SparseArrays.rowvals(adj)
220-
stack = Int[]
221-
stack_ignore = Bool[]
222-
nonlinear_group = indexedset
223-
if length(nodes) == 1 && nodes[1].type == Nonlinear.NODE_SUBEXPRESSION
224-
# Subexpression comes in linearly, so append edge_list
225-
for ij in subexpression_edgelist[nodes[1].index]
226-
push!(edge_list, ij)
227-
end
228-
end
229-
for k in 2:length(nodes)
230-
nod = nodes[k]
231-
@assert nod.type != Nonlinear.NODE_MOI_VARIABLE
232-
if nonlinear_wrt_output[k]
233-
continue # already seen this node one way or another
234-
elseif input_linearity[k] == CONSTANT
235-
continue # definitely not nonlinear
270+
# Stack entry: (node_index, child_group_index)
271+
stack = Tuple{Int,Int}[]
272+
# Map from child_group_index to variable indices
273+
child_group_variables = Dict{Int,Set{Int}}()
274+
for (k, node) in enumerate(nodes)
275+
@assert node.type != Nonlinear.NODE_MOI_VARIABLE
276+
if input_linearity[k] == CONSTANT
277+
continue # No hessian contribution from constant nodes
236278
end
237-
@assert !nonlinear_wrt_output[nod.parent]
238-
# check if the parent depends nonlinearly on the value of this node
239-
par = nodes[nod.parent]
240-
if par.type == Nonlinear.NODE_CALL_UNIVARIATE
241-
op = get(Nonlinear.DEFAULT_UNIVARIATE_OPERATORS, par.index, nothing)
242-
if op === nothing || (op != :+ && op != :-)
243-
nonlinear_wrt_output[k] = true
279+
# Check if this node has nonlinear child interactions
280+
children_idx = SparseArrays.nzrange(adj, k)
281+
num_children = length(children_idx)
282+
interactions = _get_nonlinear_child_interactions(node, num_children)
283+
if !isempty(interactions)
284+
# This node has nonlinear child interactions, so collect variables
285+
# from its children
286+
empty!(child_group_variables)
287+
# DFS from all children, tracking child index
288+
for (child_position, cidx) in enumerate(children_idx)
289+
child_node_idx = children_arr[cidx]
290+
push!(stack, (child_node_idx, child_position))
244291
end
245-
elseif par.type == Nonlinear.NODE_CALL_MULTIVARIATE
246-
op = get(
247-
Nonlinear.DEFAULT_MULTIVARIATE_OPERATORS,
248-
par.index,
249-
nothing,
250-
)
251-
if op === nothing
252-
nonlinear_wrt_output[k] = true
253-
elseif op in (:+, :-, :ifelse)
254-
# pass
255-
elseif op == :*
256-
# check if all siblings are constant
257-
sibling_idx = SparseArrays.nzrange(adj, nod.parent)
258-
if !all(
259-
i ->
260-
input_linearity[children_arr[i]] == CONSTANT ||
261-
children_arr[i] == k,
262-
sibling_idx,
263-
)
264-
# at least one sibling isn't constant
265-
nonlinear_wrt_output[k] = true
292+
while length(stack) > 0
293+
r, child_group_idx = pop!(stack)
294+
# Don't traverse into logical conditions or comparisons
295+
if nodes[r].type == Nonlinear.NODE_LOGIC ||
296+
nodes[r].type == Nonlinear.NODE_COMPARISON
297+
continue
266298
end
267-
elseif op == :/
268-
# check if denominator is nonconstant
269-
sibling_idx = SparseArrays.nzrange(adj, nod.parent)
270-
if input_linearity[children_arr[last(sibling_idx)]] != CONSTANT
271-
nonlinear_wrt_output[k] = true
299+
r_children_idx = SparseArrays.nzrange(adj, r)
300+
for cidx in r_children_idx
301+
push!(stack, (children_arr[cidx], child_group_idx))
302+
end
303+
if nodes[r].type == Nonlinear.NODE_VARIABLE
304+
if !haskey(child_group_variables, child_group_idx)
305+
child_group_variables[child_group_idx] = Set{Int}()
306+
end
307+
push!(
308+
child_group_variables[child_group_idx],
309+
nodes[r].index,
310+
)
311+
elseif nodes[r].type == Nonlinear.NODE_SUBEXPRESSION
312+
sub_vars = subexpression_variables[nodes[r].index]
313+
if !haskey(child_group_variables, child_group_idx)
314+
child_group_variables[child_group_idx] = Set{Int}()
315+
end
316+
union!(child_group_variables[child_group_idx], sub_vars)
272317
end
273-
else
274-
nonlinear_wrt_output[k] = true
275318
end
276-
end
277-
if nod.type == Nonlinear.NODE_SUBEXPRESSION && !nonlinear_wrt_output[k]
278-
# subexpression comes in linearly, so append edge_list
279-
for ij in subexpression_edgelist[nod.index]
319+
_add_hessian_edges!(edge_list, interactions, child_group_variables)
320+
elseif node.type == Nonlinear.NODE_SUBEXPRESSION
321+
for ij in subexpression_edgelist[node.index]
280322
push!(edge_list, ij)
281323
end
282324
end
283-
if !nonlinear_wrt_output[k]
284-
continue
285-
end
286-
# do a DFS from here, including all children
287-
@assert isempty(stack)
288-
@assert isempty(stack_ignore)
289-
sibling_idx = SparseArrays.nzrange(adj, nod.parent)
290-
for sidx in sibling_idx
291-
push!(stack, children_arr[sidx])
292-
push!(stack_ignore, false)
293-
end
294-
empty!(nonlinear_group)
295-
while length(stack) > 0
296-
r = pop!(stack)
297-
should_ignore = pop!(stack_ignore)
298-
nonlinear_wrt_output[r] = true
299-
if nodes[r].type == Nonlinear.NODE_LOGIC ||
300-
nodes[r].type == Nonlinear.NODE_COMPARISON
301-
# don't count the nonlinear interactions inside
302-
# logical conditions or comparisons
303-
should_ignore = true
304-
end
305-
children_idx = SparseArrays.nzrange(adj, r)
306-
for cidx in children_idx
307-
push!(stack, children_arr[cidx])
308-
push!(stack_ignore, should_ignore)
309-
end
310-
if should_ignore
311-
continue
312-
end
313-
if nodes[r].type == Nonlinear.NODE_VARIABLE
314-
push!(nonlinear_group, nodes[r].index)
315-
elseif nodes[r].type == Nonlinear.NODE_SUBEXPRESSION
316-
# append all variables in subexpression
317-
union!(nonlinear_group, subexpression_variables[nodes[r].index])
325+
end
326+
return edge_list
327+
end
328+
329+
"""
330+
_add_hessian_edges!(
331+
edge_list::Set{Tuple{Int,Int}},
332+
interactions::Vector{Tuple{Int,Int}},
333+
child_variables::Dict{Int,Set{Int}},
334+
)
335+
336+
Add hessian edges based on the operator's nonlinear interaction pattern.
337+
"""
338+
function _add_hessian_edges!(
339+
edge_list::Set{Tuple{Int,Int}},
340+
interactions::Vector{Tuple{Int,Int}},
341+
child_variables::Dict{Int,Set{Int}},
342+
)
343+
for (child_i, child_j) in interactions
344+
if child_i == child_j
345+
# Within-child interactions: add all pairs from a single child
346+
if haskey(child_variables, child_i)
347+
vars = child_variables[child_i]
348+
for vi in vars
349+
for vj in vars
350+
i, j = minmax(vi, vj)
351+
push!(edge_list, (j, i))
352+
end
353+
end
318354
end
319-
end
320-
for i_ in 1:nonlinear_group.nnz
321-
i = nonlinear_group.nzidx[i_]
322-
for j_ in 1:nonlinear_group.nnz
323-
j = nonlinear_group.nzidx[j_]
324-
if j > i
325-
continue # Only lower triangle.
355+
else
356+
# Between-child interactions: add pairs from different children
357+
if haskey(child_variables, child_i) &&
358+
haskey(child_variables, child_j)
359+
vars_i = child_variables[child_i]
360+
vars_j = child_variables[child_j]
361+
for vi in vars_i
362+
for vj in vars_j
363+
i, j = minmax(vi, vj)
364+
push!(edge_list, (j, i))
365+
end
326366
end
327-
push!(edge_list, (i, j))
328367
end
329368
end
330369
end
331-
return edge_list
370+
return
332371
end
333372

334373
"""

src/mathoptinterface_api.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
9494
subex.nodes,
9595
subex.adj,
9696
linearity,
97-
coloring_storage,
9897
subexpression_edgelist,
9998
subexpression_variables,
10099
)

src/types.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ struct _FunctionStorage
100100
expr.nodes,
101101
expr.adj,
102102
linearity,
103-
coloring_storage,
104103
subexpression_edgelist,
105104
subexpression_variables,
106105
)

test/ReverseAD.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,6 @@ function test_linearity()
552552
nodes,
553553
adj,
554554
ret,
555-
indexed_set,
556555
Set{Tuple{Int,Int}}[],
557556
Vector{Int}[],
558557
)
@@ -576,12 +575,7 @@ function test_linearity()
576575
[1, 2],
577576
)
578577
_test_linearity(:(3 * 4 * ($x + $y)), ArrayDiff.LINEAR)
579-
_test_linearity(
580-
:($z * $y),
581-
ArrayDiff.NONLINEAR,
582-
Set([(3, 2), (3, 3), (2, 2)]),
583-
[2, 3],
584-
)
578+
_test_linearity(:($z * $y), ArrayDiff.NONLINEAR, Set([(3, 2)]), [2, 3])
585579
_test_linearity(:(3 + 4), ArrayDiff.CONSTANT)
586580
_test_linearity(:(sin(3) + $x), ArrayDiff.LINEAR)
587581
_test_linearity(
@@ -626,6 +620,12 @@ function test_linearity()
626620
Set([(1, 1)]),
627621
[1],
628622
)
623+
_test_linearity(
624+
:(($x + $y)/$z),
625+
ArrayDiff.NONLINEAR,
626+
Set([(3, 3), (3, 2), (3, 1)]),
627+
[1, 2, 3],
628+
)
629629
return
630630
end
631631

@@ -1357,7 +1357,7 @@ function test_hessian_reinterpret_unsafe()
13571357
x_v = ones(5)
13581358
MOI.eval_hessian_lagrangian(evaluator, H, x_v, 0.0, [1.0, 1.0])
13591359
@test count(isapprox.(H, 1.0; atol = 1e-8)) == 3
1360-
@test count(isapprox.(H, 0.0; atol = 1e-8)) == 6
1360+
@test count(isapprox.(H, 0.0; atol = 1e-8)) == 5
13611361
@test sort(H_s[round.(Bool, H)]) == [(3, 1), (3, 2), (5, 4)]
13621362
return
13631363
end

0 commit comments

Comments
 (0)