Skip to content

Commit 2f6579f

Browse files
committed
Fix
1 parent 5cebee8 commit 2f6579f

5 files changed

Lines changed: 131 additions & 3 deletions

File tree

perf/neural.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
using JuMP
88
using ArrayDiff
9-
using LinearAlgebra
109
import NLopt
1110

1211
n = 2
@@ -28,8 +27,8 @@ end
2827
# Forward pass: Y = W2 * tanh.(W1 * X)
2928
Y = W2 * tanh.(W1 * X)
3029

31-
# Loss: ||Y - target|| (norm returns a scalar NonlinearExpr)
32-
loss = norm(Y .- target)
30+
# Loss: sum of squared errors
31+
loss = sum((Y .- target) .^ 2)
3332
@objective(model, Min, loss)
3433

3534
optimize!(model)

src/JuMP/operators.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,28 @@ function Base.broadcasted(
4545
return _broadcast(JuMP.variable_ref_type(x), op, x, y)
4646
end
4747

48+
function Base.broadcasted(op::Function, x::AbstractJuMPArray, y::Number)
49+
return _broadcast(JuMP.variable_ref_type(x), op, x, y)
50+
end
51+
52+
function Base.broadcasted(op::Function, x::Number, y::AbstractJuMPArray)
53+
return _broadcast(JuMP.variable_ref_type(y), op, x, y)
54+
end
55+
56+
function Base.broadcasted(
57+
::typeof(Base.literal_pow),
58+
::typeof(^),
59+
x::AbstractJuMPArray,
60+
::Val{y},
61+
) where {y}
62+
return Base.broadcasted(^, x, y)
63+
end
64+
65+
function Base.sum(x::GenericArrayExpr)
66+
V = JuMP.variable_ref_type(x)
67+
return JuMP.GenericNonlinearExpr{V}(:sum, Any[x])
68+
end
69+
4870
import LinearAlgebra
4971

5072
function _array_norm(x::AbstractJuMPArray)

src/reverse_mode.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,29 @@ function _forward_eval(
465465
end
466466
@inbounds f.forward_storage[k] = tmp_prod
467467
end
468+
elseif node.index == 4 # :^ (broadcasted), array .^ scalar
469+
@assert N == 2
470+
idx1 = first(children_indices)
471+
idx2 = last(children_indices)
472+
@inbounds ix1 = children_arr[idx1]
473+
@inbounds ix2 = children_arr[idx2]
474+
@assert f.sizes.ndims[ix2] == 0 "Broadcasted ^ requires scalar exponent"
475+
@inbounds exponent =
476+
f.forward_storage[f.sizes.storage_offset[ix2]+1]
477+
for j in _eachindex(f.sizes, k)
478+
base = @j f.forward_storage[ix1]
479+
if exponent == 2
480+
@j f.forward_storage[k] = base * base
481+
@j f.partials_storage[ix1] = 2 * base
482+
elseif exponent == 1
483+
@j f.forward_storage[k] = base
484+
@j f.partials_storage[ix1] = one(T)
485+
else
486+
@j f.forward_storage[k] = pow(base, exponent)
487+
@j f.partials_storage[ix1] =
488+
exponent * pow(base, exponent - 1)
489+
end
490+
end
468491
end
469492
elseif node.type == NODE_CALL_UNIVARIATE
470493
child_idx = children_arr[f.adj.colptr[k]]
@@ -816,6 +839,35 @@ function _reverse_eval(f::_SubexpressionStorage)
816839
end
817840
continue
818841
end
842+
elseif op == :^
843+
# Broadcasted array .^ scalar: per-j reverse for the base,
844+
# and a sum-reduced reverse for the (scalar) exponent.
845+
@assert length(children_indices) == 2
846+
idx1 = first(children_indices)
847+
idx2 = last(children_indices)
848+
@inbounds ix1 = children_arr[idx1]
849+
@inbounds ix2 = children_arr[idx2]
850+
for j in _eachindex(f.sizes, k)
851+
rev_parent = @j f.reverse_storage[k]
852+
partial = @j f.partials_storage[ix1]
853+
val = ifelse(
854+
rev_parent == 0.0 && !isfinite(partial),
855+
rev_parent,
856+
rev_parent * partial,
857+
)
858+
@j f.reverse_storage[ix1] = val
859+
end
860+
rev_exp = zero(Float64)
861+
for j in _eachindex(f.sizes, k)
862+
rev_parent = @j f.reverse_storage[k]
863+
base = @j f.forward_storage[ix1]
864+
out = @j f.forward_storage[k]
865+
if base > 0
866+
rev_exp += rev_parent * out * log(base)
867+
end
868+
end
869+
@s f.reverse_storage[ix2] = rev_exp
870+
continue
819871
end
820872
end
821873
elseif node.type != NODE_CALL_UNIVARIATE &&

src/sizes.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,9 @@ function _infer_sizes(
288288
if op == :+ || op == :-
289289
# Broadcasted +/- preserves shape
290290
_copy_size!(sizes, k, children_arr[first(children_indices)])
291+
elseif op == :^
292+
# Broadcasted ^ with scalar exponent preserves base shape
293+
_copy_size!(sizes, k, children_arr[first(children_indices)])
291294
elseif op == :*
292295
# TODO assert compatible sizes and all ndims should be 0 or 2
293296
first_matrix = findfirst(children_indices) do i

test/ArrayDiff.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,58 @@ function test_objective_broadcasted_tanh()
634634
return
635635
end
636636

637+
function test_objective_broadcasted_pow_vector()
638+
model = ArrayDiff.Model()
639+
x1 = MOI.VariableIndex(1)
640+
x2 = MOI.VariableIndex(2)
641+
ArrayDiff.set_objective(model, :(sum([$x1, $x2] .^ 2)))
642+
evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x1, x2])
643+
MOI.initialize(evaluator, [:Grad])
644+
x1v = 3.0
645+
x2v = -4.0
646+
@test MOI.eval_objective(evaluator, [x1v, x2v]) == x1v^2 + x2v^2
647+
g = ones(2)
648+
MOI.eval_objective_gradient(evaluator, g, [x1v, x2v])
649+
@test g == [2 * x1v, 2 * x2v]
650+
return
651+
end
652+
653+
function test_objective_broadcasted_pow_matrix_with_constant()
654+
model = ArrayDiff.Model()
655+
x1 = MOI.VariableIndex(1)
656+
x2 = MOI.VariableIndex(2)
657+
x3 = MOI.VariableIndex(3)
658+
x4 = MOI.VariableIndex(4)
659+
ArrayDiff.set_objective(
660+
model,
661+
:(sum(([$x1 $x2; $x3 $x4] - [1 1; 1 1]) .^ 2)),
662+
)
663+
evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3, x4])
664+
MOI.initialize(evaluator, [:Grad])
665+
xs = [1.0, 2.0, 3.0, 4.0]
666+
@test MOI.eval_objective(evaluator, xs) ==
667+
(1-1)^2 + (2-1)^2 + (3-1)^2 + (4-1)^2
668+
g = ones(4)
669+
MOI.eval_objective_gradient(evaluator, g, xs)
670+
@test g == [2 * (1 - 1), 2 * (2 - 1), 2 * (3 - 1), 2 * (4 - 1)]
671+
return
672+
end
673+
674+
function test_objective_broadcasted_pow_cubed()
675+
model = ArrayDiff.Model()
676+
x1 = MOI.VariableIndex(1)
677+
x2 = MOI.VariableIndex(2)
678+
ArrayDiff.set_objective(model, :(sum([$x1, $x2] .^ 3)))
679+
evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x1, x2])
680+
MOI.initialize(evaluator, [:Grad])
681+
xs = [2.0, 3.0]
682+
@test MOI.eval_objective(evaluator, xs) ≈ 2.0^3 + 3.0^3
683+
g = ones(2)
684+
MOI.eval_objective_gradient(evaluator, g, xs)
685+
@test g ≈ [3 * 2.0^2, 3 * 3.0^2]
686+
return
687+
end
688+
637689
end # module
638690

639691
TestArrayDiff.runtests()

0 commit comments

Comments
 (0)