Skip to content

Commit d8da865

Browse files
committed
Allow reduce + tests
**Note: it is not a NodeType for the moment** — this only extends the reduce notation
1 parent f638758 commit d8da865

3 files changed

Lines changed: 108 additions & 1 deletion

File tree

src/operators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ const DEFAULT_MULTIVARIATE_OPERATORS = [
1818
:norm,
1919
:sum,
2020
:row,
21+
:reduce,
2122
]
2223

2324
function _validate_register_assumptions(

src/parse.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,9 @@ end
158158

159159
function _parse_expression(stack, data, expr, x, parent_index)
160160
if Meta.isexpr(x, :call)
161-
if length(x.args) == 2 && !Meta.isexpr(x.args[2], :...)
161+
if x.args[1] == :reduce
162+
_parse_reduce_expression(stack, data, expr, x, parent_index)
163+
elseif length(x.args) == 2 && !Meta.isexpr(x.args[2], :...)
162164
_parse_univariate_expression(stack, data, expr, x, parent_index)
163165
else
164166
# The call is either n-ary, or it is a splat, in which case we
@@ -278,6 +280,36 @@ function _parse_vcat_expression(
278280
return
279281
end
280282

283+
"""
    _parse_reduce_expression(stack, data, expr, x, parent_index)

Expand a `reduce(op, [a, b, ...])` call into nested binary applications of
`op` (a left fold), and push the resulting expression onto `stack` for
further parsing. `reduce` is not a node type itself: only the notation is
rewritten here. Throws an error for malformed calls, non-literal
collections, or an empty collection.
"""
function _parse_reduce_expression(stack, data, expr, x, parent_index)
    # Exactly `reduce(op, collection)` — no `init`, no extra arguments.
    length(x.args) == 3 ||
        error("Unsupported reduce expression: $x. Expected reduce(op, collection).")
    op = x.args[2]
    collection = x.args[3]
    # Only vector literals can be unrolled at parse time.
    Meta.isexpr(collection, :vect) ||
        error("Unsupported reduce collection: $collection. Expected a vector literal.")
    elements = collection.args
    isempty(elements) && error("Unsupported reduce on empty collection.")
    # Left fold: reduce(op, [a, b, c]) -> op(op(a, b), c). A single element
    # folds to itself, matching the original special case.
    folded = foldl((lhs, rhs) -> Expr(:call, op, lhs, rhs), elements)
    push!(stack, (parent_index, folded))
    return
end
312+
281313
function _parse_inequality_expression(
282314
stack::Vector{Tuple{Int,Any}},
283315
data::Model,

test/ArrayDiff.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,80 @@ function test_objective_broadcasted_tanh()
634634
return
635635
end
636636

637+
function test_objective_reduce_sum()
    # reduce(+, [x1, x2, x3]) must lower to nested binary additions and
    # evaluate/differentiate exactly like x1 + x2 + x3.
    model = ArrayDiff.Model()
    vars = [MOI.VariableIndex(i) for i in 1:3]
    ArrayDiff.set_objective(
        model,
        :(reduce(+, [$(vars[1]), $(vars[2]), $(vars[3])])),
    )
    evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), vars)
    MOI.initialize(evaluator, [:Grad])
    sizes = evaluator.backend.objective.expr.sizes
    # Every node of the lowered expression is a scalar (ndims == 0).
    @test sizes.ndims == [0, 0, 0, 0, 0]
    @test sizes.size_offset == [0, 0, 0, 0, 0]
    @test sizes.size == []
    @test sizes.storage_offset == [0, 1, 2, 3, 4, 5]
    point = [1.0, 2.0, 3.0]
    @test MOI.eval_objective(evaluator, point) == 6.0
    g = ones(3)
    MOI.eval_objective_gradient(evaluator, g, point)
    # d(x1 + x2 + x3)/dxi == 1 for each i.
    @test g == [1.0, 1.0, 1.0]
    return
end
659+
660+
function test_objective_reduce_prod()
    # reduce(*, [x1, x2, x3]) must lower to nested binary products and
    # evaluate/differentiate exactly like x1 * x2 * x3.
    model = ArrayDiff.Model()
    vars = [MOI.VariableIndex(i) for i in 1:3]
    ArrayDiff.set_objective(
        model,
        :(reduce(*, [$(vars[1]), $(vars[2]), $(vars[3])])),
    )
    evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), vars)
    MOI.initialize(evaluator, [:Grad])
    sizes = evaluator.backend.objective.expr.sizes
    # Every node of the lowered expression is a scalar (ndims == 0).
    @test sizes.ndims == [0, 0, 0, 0, 0]
    @test sizes.size_offset == [0, 0, 0, 0, 0]
    @test sizes.size == []
    @test sizes.storage_offset == [0, 1, 2, 3, 4, 5]
    point = [1.0, 2.0, 3.0]
    @test MOI.eval_objective(evaluator, point) == 6.0
    g = ones(3)
    MOI.eval_objective_gradient(evaluator, g, point)
    # d(prod)/dxi == prod / xi at this (nonzero) point.
    @test g == [6.0 / point[1], 6.0 / point[2], 6.0 / point[3]]
    return
end
682+
683+
function test_objective_reduce_atan()
    # reduce(atan, [x1, x2, x3]) lowers to the left fold atan(atan(x1, x2), x3),
    # exercising a non-associative binary operator.
    model = ArrayDiff.Model()
    x1 = MOI.VariableIndex(1)
    x2 = MOI.VariableIndex(2)
    x3 = MOI.VariableIndex(3)
    ArrayDiff.set_objective(model, :(reduce(atan, [$x1, $x2, $x3])))
    evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x1, x2, x3])
    MOI.initialize(evaluator, [:Grad])
    sizes = evaluator.backend.objective.expr.sizes
    # Every node of the lowered expression is a scalar (ndims == 0).
    @test sizes.ndims == [0, 0, 0, 0, 0]
    @test sizes.size_offset == [0, 0, 0, 0, 0]
    @test sizes.size == []
    @test sizes.storage_offset == [0, 1, 2, 3, 4, 5]
    x1 = 1.0
    x2 = 2.0
    x3 = 3.0
    @test MOI.eval_objective(evaluator, [x1, x2, x3]) ==
          atan(atan(x1, x2), x3)
    g = ones(3)
    MOI.eval_objective_gradient(evaluator, g, [x1, x2, x3])
    # Bug fix: the comparison operator was missing (`@test g [...]` does not
    # parse). Use ≈ (isapprox) since the analytic gradient involves
    # floating-point arithmetic.
    @test g ≈ [
        x2 * x3 / ((x1^2 + x2^2) * (x3^2 + atan(x1, x2)^2)),
        -x1 * x3 / ((x1^2 + x2^2) * (x3^2 + atan(x1, x2)^2)),
        -atan(x1, x2) / (x3^2 + atan(x1, x2)^2),
    ]
    return
end
710+
637711
end # module
638712

639713
TestArrayDiff.runtests()

0 commit comments

Comments
 (0)