Commit 1291166

Add _is_real for GenericArrayExpr and restore full tests
JuMP's GenericNonlinearExpr constructor validates arguments via _is_real(). GenericArrayExpr was missing this method, causing norm() to fail when wrapping array expressions in NonlinearExpr. Also consolidate the L2 loss tests into a single comprehensive test. https://claude.ai/code/session_01GWT1QHA3D5BpMQBEHvgbcV
1 parent 15ac1e3 commit 1291166
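
For context, a rough sketch of the failure path described in the message. The constructor behavior in the comments is paraphrased from the commit message rather than copied from JuMP's source, and it assumes this package's GenericArrayExpr type is in scope; treat it as illustrative only.

import JuMP

# norm(x) on an array expression ends up building, as the old test did by hand,
#     JuMP.GenericNonlinearExpr{JuMP.VariableRef}(:norm, Any[x])
# and that constructor validates every argument with JuMP._is_real. Types without
# an _is_real method are rejected, which is why norm() on a GenericArrayExpr errored.
#
# The one-line opt-in added in src/JuMP/nlp_expr.jl:
JuMP._is_real(::GenericArrayExpr) = true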

2 files changed: 10 additions & 6 deletions


src/JuMP/nlp_expr.jl

Lines changed: 2 additions & 0 deletions
@@ -20,3 +20,5 @@ end
 Base.size(expr::GenericArrayExpr) = expr.size
 
 JuMP.variable_ref_type(::Type{GenericArrayExpr{V,N}}) where {V,N} = V
+
+JuMP._is_real(::GenericArrayExpr) = true
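
With the method in place, LinearAlgebra.norm can wrap an array expression directly, which is what the restored test below exercises. A minimal usage sketch, assuming the package is loaded so that W2 * tanh.(W1 * X) produces an ArrayDiff.MatrixExpr rather than a plain Julia array, and assuming W1 and W2 are JuMP variable matrices as in the test (their declaration is not shown in the diff):

using JuMP, LinearAlgebra

model = Model()
n = 2
X = rand(n, n)
Y = rand(n, n)
@variable(model, W1[1:n, 1:n])   # assumed setup, mirroring the test
@variable(model, W2[1:n, 1:n])

Y_hat = W2 * tanh.(W1 * X)       # a MatrixExpr in this package
diff_expr = Y_hat .- Y           # broadcasted :- expression over the array
loss = LinearAlgebra.norm(diff_expr)   # builds a scalar NonlinearExpr with head :norm
# loss isa JuMP.NonlinearExpr and loss.args[1] === diff_expr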

test/JuMP.jl

Lines changed: 8 additions & 6 deletions
@@ -120,7 +120,7 @@ function test_l2_loss_tanh()
     return
 end
 
-function test_l2_loss_nested()
+function test_l2_loss()
     n = 2
     X = rand(n, n)
     Y = rand(n, n)
@@ -130,12 +130,14 @@ function test_l2_loss_nested()
     Y_hat = W2 * tanh.(W1 * X)
     diff_expr = Y_hat .- Y
     @test diff_expr isa ArrayDiff.MatrixExpr
-    # Test creating NonlinearExpr manually
-    loss = JuMP.GenericNonlinearExpr{JuMP.VariableRef}(
-        :norm,
-        Any[diff_expr],
-    )
+    @test diff_expr.head == :-
+    @test diff_expr.broadcasted
+    @test diff_expr.args[1] === Y_hat
+    @test diff_expr.args[2] === Y
+    loss = LinearAlgebra.norm(diff_expr)
     @test loss isa JuMP.NonlinearExpr
+    @test loss.head == :norm
+    @test loss.args[1] === diff_expr
     return
 end
