Skip to content

Commit 26bff65

Browse files
committed
Merge branch 'master' of github.com:wundergraph/graphql-go-tools into ludwig/improve-entity-mapping
2 parents 141a4c9 + eba0f58 commit 26bff65

12 files changed

Lines changed: 3447 additions & 2564 deletions

File tree

v2/pkg/engine/datasource/grpc_datasource/execution_plan.go

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,11 @@ func NewPlanner(subgraphName string, mapping *GRPCMapping, federationConfigs pla
365365

366366
// formatRPCMessage formats an RPCMessage and adds it to the string builder with the specified indentation
367367
func formatRPCMessage(sb *strings.Builder, message RPCMessage, indent int) {
368+
visited := make(map[*RPCMessage]struct{})
369+
formatRPCMessageVisited(sb, message, indent, visited)
370+
}
371+
372+
func formatRPCMessageVisited(sb *strings.Builder, message RPCMessage, indent int, visited map[*RPCMessage]struct{}) {
368373
indentStr := strings.Repeat(" ", indent)
369374

370375
fmt.Fprintf(sb, "%sName: %s\n", indentStr, message.Name)
@@ -377,25 +382,34 @@ func formatRPCMessage(sb *strings.Builder, message RPCMessage, indent int) {
377382
fmt.Fprintf(sb, "%s JSONPath: %s\n", indentStr, field.JSONPath)
378383
fmt.Fprintf(sb, "%s ResolvePath: %s\n", indentStr, field.ResolvePath.String())
379384

380-
if field.Message != nil {
381-
fmt.Fprintf(sb, "%s Message:\n", indentStr)
382-
formatRPCMessage(sb, *field.Message, indent+6)
385+
if field.Message == nil {
386+
return
387+
}
388+
389+
fmt.Fprintf(sb, "%s Message:\n", indentStr)
390+
if _, seen := visited[field.Message]; seen {
391+
fmt.Fprintf(sb, "%s <recursive: %s>\n", indentStr, field.Message.Name)
392+
continue
383393
}
394+
visited[field.Message] = struct{}{}
395+
formatRPCMessageVisited(sb, *field.Message, indent+6, visited)
384396
}
385397
}
386398

387399
type rpcPlanningContext struct {
388-
operation *ast.Document
389-
definition *ast.Document
390-
mapping *GRPCMapping
400+
operation *ast.Document
401+
definition *ast.Document
402+
mapping *GRPCMapping
403+
visitedInputTypes map[string]*RPCMessage
391404
}
392405

393406
// newRPCPlanningContext creates a new RPCPlanningContext.
394407
func newRPCPlanningContext(operation *ast.Document, definition *ast.Document, mapping *GRPCMapping) *rpcPlanningContext {
395408
return &rpcPlanningContext{
396-
operation: operation,
397-
definition: definition,
398-
mapping: mapping,
409+
operation: operation,
410+
definition: definition,
411+
mapping: mapping,
412+
visitedInputTypes: make(map[string]*RPCMessage, len(definition.InputObjectTypeDefinitions)),
399413
}
400414
}
401415

@@ -686,11 +700,26 @@ func (r *rpcPlanningContext) buildMessageFromInputObjectType(node *ast.Node) (*R
686700
return nil, fmt.Errorf("unable to build message from input object type definition - incorrect type: %s", node.Kind)
687701
}
688702

703+
typeName := node.NameString(r.definition)
704+
705+
// If we've already started building this type, return the in-progress message
706+
// pointer to break the recursion cycle. The message's fields are populated by
707+
// the caller that first entered this type, so the pointer will be complete once
708+
// the top-level call returns.
709+
if existing, ok := r.visitedInputTypes[typeName]; ok {
710+
return existing, nil
711+
}
712+
689713
inputObjectDefinition := r.definition.InputObjectTypeDefinitions[node.Ref]
690714
message := &RPCMessage{
691-
Name: node.NameString(r.definition),
715+
Name: typeName,
692716
Fields: make(RPCFields, 0, len(inputObjectDefinition.InputFieldsDefinition.Refs)),
693717
}
718+
719+
// Register the message before recursing into fields so that recursive
720+
// references resolve to this same pointer.
721+
r.visitedInputTypes[typeName] = message
722+
694723
for _, inputFieldRef := range inputObjectDefinition.InputFieldsDefinition.Refs {
695724
field, err := r.buildMessageFieldFromInputValueDefinition(inputFieldRef, node)
696725
if err != nil {
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
package grpcdatasource
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
8+
"github.com/wundergraph/graphql-go-tools/v2/pkg/astparser"
9+
)
10+
11+
func TestExecutionPlan_RecursiveInputTypes_String(t *testing.T) {
12+
// Verify stringer method does not overflow on recursive inputs
13+
t.Parallel()
14+
15+
schema := `
16+
type Query {
17+
search(conditions: ConditionsInput): [Result!]!
18+
}
19+
20+
type Result {
21+
id: ID!
22+
name: String!
23+
}
24+
25+
input ConditionsInput {
26+
and: [ConditionsInput!]
27+
or: [ConditionsInput!]
28+
key: String
29+
value: String
30+
}`
31+
32+
mapping := &GRPCMapping{
33+
Service: "Search",
34+
QueryRPCs: map[string]RPCConfig{
35+
"search": {
36+
RPC: "Search",
37+
Request: "SearchRequest",
38+
Response: "SearchResponse",
39+
},
40+
},
41+
}
42+
43+
query := `query SearchQuery($conditions: ConditionsInput) { search(conditions: $conditions) { id name } }`
44+
45+
plan := planRecursiveTest(t, query, schema, mapping)
46+
47+
result := plan.String()
48+
// formatRPCMessage must emit the recursive placeholder instead of overflowing.
49+
require.Contains(t, result, "ConditionsInput")
50+
require.Contains(t, result, "<recursive: ConditionsInput>")
51+
}
52+
53+
func TestExecutionPlan_RecursiveInputTypes(t *testing.T) {
54+
t.Parallel()
55+
56+
t.Run("Should not stack overflow on recursive input object with and/or fields", func(t *testing.T) {
57+
t.Parallel()
58+
59+
schema := `
60+
type Query {
61+
search(conditions: ConditionsInput): [Result!]!
62+
}
63+
64+
type Result {
65+
id: ID!
66+
name: String!
67+
}
68+
69+
input ConditionsInput {
70+
and: [ConditionsInput!]
71+
or: [ConditionsInput!]
72+
key: String
73+
value: String
74+
}`
75+
76+
mapping := &GRPCMapping{
77+
Service: "Search",
78+
QueryRPCs: map[string]RPCConfig{
79+
"search": {
80+
RPC: "Search",
81+
Request: "SearchRequest",
82+
Response: "SearchResponse",
83+
},
84+
},
85+
}
86+
87+
query := `query SearchQuery($conditions: ConditionsInput) { search(conditions: $conditions) { id name } }`
88+
89+
plan := planRecursiveTest(t, query, schema, mapping)
90+
91+
require.Len(t, plan.Calls, 1)
92+
call := plan.Calls[0]
93+
require.Equal(t, "Search", call.MethodName)
94+
95+
// The request should have a conditions field with a recursive message.
96+
require.Len(t, call.Request.Fields, 1)
97+
conditionsField := call.Request.Fields[0]
98+
require.Equal(t, "conditions", conditionsField.JSONPath)
99+
require.NotNil(t, conditionsField.Message)
100+
require.Equal(t, "ConditionsInput", conditionsField.Message.Name)
101+
require.Len(t, conditionsField.Message.Fields, 4)
102+
103+
// The and/or fields should reference the same ConditionsInput message (cycle).
104+
andField := findField(t, conditionsField.Message.Fields, "and")
105+
orField := findField(t, conditionsField.Message.Fields, "or")
106+
require.True(t, andField.Message == conditionsField.Message, "and field should reference the same ConditionsInput message")
107+
require.True(t, orField.Message == conditionsField.Message, "or field should reference the same ConditionsInput message")
108+
})
109+
110+
t.Run("Should not stack overflow on self-referencing input object", func(t *testing.T) {
111+
t.Parallel()
112+
113+
schema := `
114+
type Query {
115+
filter(input: FilterInput): [Item!]!
116+
}
117+
118+
type Item {
119+
id: ID!
120+
}
121+
122+
input FilterInput {
123+
child: FilterInput
124+
value: String
125+
}`
126+
127+
mapping := &GRPCMapping{
128+
Service: "Items",
129+
QueryRPCs: map[string]RPCConfig{
130+
"filter": {
131+
RPC: "Filter",
132+
Request: "FilterRequest",
133+
Response: "FilterResponse",
134+
},
135+
},
136+
}
137+
138+
query := `query FilterQuery($input: FilterInput) { filter(input: $input) { id } }`
139+
140+
plan := planRecursiveTest(t, query, schema, mapping)
141+
142+
require.Len(t, plan.Calls, 1)
143+
call := plan.Calls[0]
144+
require.Equal(t, "Filter", call.MethodName)
145+
146+
require.Len(t, call.Request.Fields, 1)
147+
inputField := call.Request.Fields[0]
148+
require.Equal(t, "input", inputField.JSONPath)
149+
require.NotNil(t, inputField.Message)
150+
require.Equal(t, "FilterInput", inputField.Message.Name)
151+
require.Len(t, inputField.Message.Fields, 2)
152+
153+
// The child field should reference the same FilterInput message.
154+
childField := findField(t, inputField.Message.Fields, "child")
155+
require.True(t, childField.Message == inputField.Message, "child field should reference the same FilterInput message")
156+
})
157+
158+
t.Run("Should not stack overflow on mutually recursive input objects", func(t *testing.T) {
159+
t.Parallel()
160+
161+
schema := `
162+
type Query {
163+
evaluate(expr: ExprInput): Boolean!
164+
}
165+
166+
input ExprInput {
167+
not: NotExprInput
168+
value: String
169+
}
170+
171+
input NotExprInput {
172+
expr: ExprInput
173+
}`
174+
175+
mapping := &GRPCMapping{
176+
Service: "Eval",
177+
QueryRPCs: map[string]RPCConfig{
178+
"evaluate": {
179+
RPC: "Evaluate",
180+
Request: "EvaluateRequest",
181+
Response: "EvaluateResponse",
182+
},
183+
},
184+
}
185+
186+
query := `query EvalQuery($expr: ExprInput) { evaluate(expr: $expr) }`
187+
188+
plan := planRecursiveTest(t, query, schema, mapping)
189+
190+
require.Len(t, plan.Calls, 1)
191+
call := plan.Calls[0]
192+
require.Equal(t, "Evaluate", call.MethodName)
193+
194+
require.Len(t, call.Request.Fields, 1)
195+
exprField := call.Request.Fields[0]
196+
require.Equal(t, "expr", exprField.JSONPath)
197+
require.NotNil(t, exprField.Message)
198+
require.Equal(t, "ExprInput", exprField.Message.Name)
199+
require.Len(t, exprField.Message.Fields, 2)
200+
201+
// ExprInput.not -> NotExprInput.expr -> ExprInput (cycle)
202+
notField := findField(t, exprField.Message.Fields, "not")
203+
require.NotNil(t, notField.Message)
204+
require.Equal(t, "NotExprInput", notField.Message.Name)
205+
require.Len(t, notField.Message.Fields, 1)
206+
207+
backRef := findField(t, notField.Message.Fields, "expr")
208+
require.True(t, backRef.Message == exprField.Message, "NotExprInput.expr should reference the same ExprInput message")
209+
})
210+
}
211+
212+
func findField(t *testing.T, fields RPCFields, jsonPath string) RPCField {
213+
t.Helper()
214+
215+
for _, f := range fields {
216+
if f.JSONPath == jsonPath {
217+
return f
218+
}
219+
}
220+
221+
t.Fatalf("field with JSONPath %q not found", jsonPath)
222+
return RPCField{}
223+
}
224+
225+
func planRecursiveTest(t *testing.T, query, schema string, mapping *GRPCMapping) *RPCExecutionPlan {
226+
t.Helper()
227+
228+
schemaDoc := testSchema(t, schema)
229+
230+
queryDoc, report := astparser.ParseGraphqlDocumentString(query)
231+
require.False(t, report.HasErrors())
232+
233+
rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{
234+
subgraphName: mapping.Service,
235+
mapping: mapping,
236+
})
237+
238+
plan, err := rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc)
239+
require.NoError(t, err)
240+
241+
return plan
242+
}

v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,15 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D
9595
// It processes the input JSON data to make gRPC calls and returns
9696
// the response data.
9797
//
98-
// Headers are converted to gRPC metadata and part of gRPC calls.
98+
// Headers are converted to gRPC metadata and are part of gRPC calls.
9999
//
100100
// The input is expected to contain the necessary information to make
101101
// a gRPC call, including service name, method name, and request data.
102102
func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) {
103103
// get variables from input
104104
variables := gjson.Parse(unsafebytes.BytesToString(input)).Get("body.variables")
105105

106-
var (
107-
poolItems []*arena.PoolItem
108-
)
106+
var poolItems []*arena.PoolItem
109107
defer func() {
110108
d.pool.ReleaseMany(poolItems)
111109
}()

0 commit comments

Comments
 (0)