
Commit e653ee2
Fixes
1 parent 9e752c9

5 files changed: 71 additions & 80 deletions

perf/neural.jl

Lines changed: 5 additions & 6 deletions
@@ -5,8 +5,8 @@
 # first-order NLP solver.
 
 using JuMP
-import MathOptInterface as MOI
 using ArrayDiff
+using LinearAlgebra
 import NLopt
 
 n = 2
@@ -28,12 +28,11 @@ end
 # Forward pass: Y = W2 * tanh.(W1 * X)
 Y = W2 * tanh.(W1 * X)
 
-# Loss: sum of squared differences
-diff = Y - target
-loss = ArrayDiff.sumsq(diff)
+# Loss: ||Y - target|| (norm returns a scalar NonlinearExpr)
+# Pre-compute expression before @objective to avoid macro rewriting of `.-`
+loss = norm(Y .- target)
+@objective(model, Min, loss)
 
-# Set the NLP objective and optimize
-ArrayDiff.set_nlp_objective!(model, MOI.MIN_SENSE, loss)
 optimize!(model)
 
 println("Termination status: ", termination_status(model))

src/JuMP/moi_bridge.jl

Lines changed: 36 additions & 67 deletions
@@ -1,5 +1,6 @@
 # Conversion from JuMP array types to MOI ArrayNonlinearFunction,
-# to Julia Expr for ArrayDiff parsing, and NLPBlock setup helpers.
+# to Julia Expr for ArrayDiff parsing, and NLPBlock setup via
+# JuMP.set_objective_function override.
 
 # ── moi_function: JuMP → MOI ─────────────────────────────────────────────────
 
@@ -115,41 +116,14 @@ function to_expr(x::Expr)
     return x
 end
 
-# ── Scalar expression from array operations ──────────────────────────────────
+# ── to_expr for JuMP scalar nonlinear expressions ────────────────────────────
 
-"""
-    ArrayScalarExpr
-
-A scalar-valued expression that operates on array subexpressions (e.g.,
-`dot(A, B)`, `sum(A)`, `norm(A)`). This is the result type of scalar
-reductions on `GenericArrayExpr`.
-"""
-struct ArrayScalarExpr
-    head::Symbol
-    args::Vector{Any}
-end
-
-function to_expr(x::ArrayScalarExpr)
+function to_expr(x::JuMP.GenericNonlinearExpr)
     return Expr(:call, x.head, Any[to_expr(a) for a in x.args]...)
 end
 
-"""
-    ArrayDiff.dot(x, y)
-
-Compute the dot product (sum of elementwise products) of two array expressions.
-Returns an `ArrayScalarExpr` (scalar).
-"""
-function dot(x, y)
-    return ArrayScalarExpr(:dot, Any[x, y])
-end
-
-"""
-    ArrayDiff.sumsq(x)
-
-Compute the sum of squares of an array expression. Equivalent to `dot(x, x)`.
-"""
-function sumsq(x)
-    return dot(x, x)
+function to_expr(x::JuMP.GenericVariableRef)
+    return JuMP.index(x)
 end
 
 # ── parse_expression for ArrayNonlinearFunction ──────────────────────────────
@@ -172,48 +146,43 @@ function parse_expression(
     return parse_expression(data, expr, to_expr(x), parent_index)
 end
 
-# ── NLPBlock setup helpers ───────────────────────────────────────────────────
+# ── Detect whether a JuMP expression contains array args ─────────────────────
 
-"""
-    set_nlp_objective!(jmodel::JuMP.Model, sense, objective)
-
-Build an `ArrayDiff.Model` from the given `objective` expression (which may be
-an `ArrayScalarExpr`, `GenericArrayExpr`, `ArrayNonlinearFunction`, or plain
-`Expr`), create an `ArrayDiff.Evaluator` with first-order AD, and set the
-resulting `MOI.NLPBlockData` on the JuMP model's backend.
-
-## Example
-
-```julia
-model = Model(NLopt.Optimizer)
-@variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
-Y = W * X
-diff = Y - target
-ArrayDiff.set_nlp_objective!(model, MOI.MIN_SENSE, ArrayDiff.sumsq(diff))
-optimize!(model)
-```
-"""
-function set_nlp_objective!(
-    jmodel::JuMP.Model,
-    sense::MOI.OptimizationSense,
-    objective,
-)
-    # Collect ordered variables
+_has_array_args(::Any) = false
+_has_array_args(::AbstractJuMPArray) = true
+_has_array_args(::ArrayNonlinearFunction) = true
+
+function _has_array_args(x::JuMP.GenericNonlinearExpr)
+    return any(_has_array_args, x.args)
+end
+
+# ── Override set_objective_function for array-valued nonlinear expressions ────
+
+function _set_arraydiff_nlp_block!(
+    jmodel::JuMP.GenericModel{T},
+    func::JuMP.GenericNonlinearExpr{JuMP.GenericVariableRef{T}},
+) where {T}
     vars = JuMP.all_variables(jmodel)
     ordered_variables = [JuMP.index(v) for v in vars]
-
-    # Build ArrayDiff Model
     ad_model = Model()
-    obj_expr = to_expr(objective)
+    obj_expr = to_expr(func)
     set_objective(ad_model, obj_expr)
-
-    # Create evaluator (first-order AD)
     evaluator = Evaluator(ad_model, Mode(), ordered_variables)
     nlp_data = MOI.NLPBlockData(evaluator)
+    MOI.set(JuMP.backend(jmodel), MOI.NLPBlock(), nlp_data)
+    return
+end
 
-    # Set on the JuMP backend
-    backend = JuMP.backend(jmodel)
-    MOI.set(backend, MOI.NLPBlock(), nlp_data)
-    MOI.set(backend, MOI.ObjectiveSense(), sense)
+function JuMP.set_objective_function(
+    model::JuMP.GenericModel{T},
+    func::JuMP.GenericNonlinearExpr{JuMP.GenericVariableRef{T}},
+) where {T<:Real}
+    if _has_array_args(func)
+        return _set_arraydiff_nlp_block!(model, func)
+    end
+    # Fall back to standard JuMP: convert to MOI and set on backend.
+    f = JuMP.moi_function(func)
+    attr = MOI.ObjectiveFunction{typeof(f)}()
+    MOI.set(JuMP.backend(model), attr, f)
    return
end
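
Since `@objective` lowers to `set_objective_function`, this override makes plain JuMP syntax route array-valued expressions into the NLPBlock path. A hedged sketch of both dispatch branches (model contents are illustrative; assumes `ArrayDiff.ArrayOfVariables` is available as in the perf script):

```julia
using JuMP, LinearAlgebra, ArrayDiff
import MathOptInterface as MOI

model = Model()
@variable(model, W[1:2, 1:2], container = ArrayDiff.ArrayOfVariables)
@variable(model, x)

set_objective_sense(model, MOI.MIN_SENSE)

# The expression tree contains array arguments, so _has_array_args returns
# true and the objective is installed as an MOI.NLPBlock via
# _set_arraydiff_nlp_block!:
set_objective_function(model, norm(W * ones(2, 2)))

# A purely scalar nonlinear expression falls through to the standard JuMP
# path (MOI.ObjectiveFunction on the backend):
set_objective_function(model, sin(x) + x^2)
```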

src/reverse_mode.jl

Lines changed: 22 additions & 1 deletion
@@ -388,7 +388,28 @@ function _forward_eval(
     elseif node.type == NODE_CALL_MULTIVARIATE_BROADCASTED
         children_indices = SparseArrays.nzrange(f.adj, k)
         N = length(children_indices)
-        if node.index == node.index == 3 # :*
+        if node.index == 1 # :+ (broadcasted)
+            for j in _eachindex(f.sizes, k)
+                tmp_sum = zero(T)
+                for c_idx in children_indices
+                    ix = children_arr[c_idx]
+                    @j f.partials_storage[ix] = one(T)
+                    tmp_sum += @j f.forward_storage[ix]
+                end
+                @j f.forward_storage[k] = tmp_sum
+            end
+        elseif node.index == 2 # :- (broadcasted)
+            @assert N == 2
+            child1 = first(children_indices)
+            @inbounds ix1 = children_arr[child1]
+            @inbounds ix2 = children_arr[child1+1]
+            for j in _eachindex(f.sizes, k)
+                @j f.partials_storage[ix1] = one(T)
+                @j f.partials_storage[ix2] = -one(T)
+                @j f.forward_storage[k] =
+                    @j(f.forward_storage[ix1]) - @j(f.forward_storage[ix2])
+            end
+        elseif node.index == 3 # :* (broadcasted)
             # Node `k` is not scalar, so we do matrix multiplication
             if f.sizes.ndims[k] != 0
                 @assert N == 2
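
The new branches record constant partials for broadcasted `+`/`-` (all ones, and `(1, -1)` respectively) while accumulating forward values elementwise. A standalone numeric sanity check of those rules, independent of the `@j` storage macros:

```julia
# d/da (a .+ b) = 1 and d/db (a .+ b) = 1;
# d/da (a .- b) = 1 and d/db (a .- b) = -1.
a, b, h = rand(3), rand(3), 1e-7
fd_plus  = ((a .+ h) .+ b .- (a .+ b)) ./ h   # finite difference in a
fd_minus = (a .- (b .+ h) .- (a .- b)) ./ h   # finite difference in b
@assert all(isapprox.(fd_plus, 1.0; atol = 1e-5))
@assert all(isapprox.(fd_minus, -1.0; atol = 1e-5))
```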

src/sizes.jl

Lines changed: 4 additions & 1 deletion
@@ -285,7 +285,10 @@ function _infer_sizes(
             continue
         end
         op = DEFAULT_MULTIVARIATE_OPERATORS[node.index]
-        if op == :*
+        if op == :+ || op == :-
+            # Broadcasted +/- preserves shape
+            _copy_size!(sizes, k, children_arr[first(children_indices)])
+        elseif op == :*
             # TODO assert compatible sizes and all ndims should be 0 or 2
             first_matrix = findfirst(children_indices) do i
                 return !iszero(sizes.ndims[children_arr[i]])
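
The new branch encodes the same shape rule as Base broadcasting for same-shaped operands: `.+`/`.-` preserve the operand shape, unlike `*`, which follows matrix-multiplication sizing. A quick reference check in plain Julia:

```julia
A, B = rand(2, 3), rand(2, 3)
@assert size(A .+ B) == size(A)   # broadcasted :+ preserves shape
@assert size(A .- B) == size(A)   # broadcasted :- preserves shape
C = rand(3, 4)
@assert size(A * C) == (2, 4)     # :* performs matrix multiplication
```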

test/JuMP.jl

Lines changed: 4 additions & 5 deletions
@@ -145,11 +145,11 @@ function test_to_expr()
     X = rand(2, 2)
     Y = W * tanh.(W * X)
     diff = Y - X
-    loss = ArrayDiff.sumsq(diff)
+    loss = LinearAlgebra.norm(diff)
     expr = ArrayDiff.to_expr(loss)
     @test expr isa Expr
     @test expr.head == :call
-    @test expr.args[1] == :dot
+    @test expr.args[1] == :norm
     return
 end
 
@@ -183,9 +183,8 @@ function test_neural_nlopt()
         set_start_value(W2[i, j], start_W2[i, j])
     end
     Y = W2 * tanh.(W1 * X)
-    diff = Y - target
-    loss = ArrayDiff.sumsq(diff)
-    ArrayDiff.set_nlp_objective!(model, MOI.MIN_SENSE, loss)
+    loss = LinearAlgebra.norm(Y .- target)
+    @objective(model, Min, loss)
     optimize!(model)
     @test termination_status(model) == MOI.LOCALLY_SOLVED
     @test objective_value(model) < 1e-6
