Skip to content

Commit fd0879f

Browse files
committed
Merge branch 'master' into jesse/eng-8566-new-subscription-client
2 parents 03de100 + fc6af0f commit fd0879f

19 files changed

Lines changed: 3933 additions & 2855 deletions
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package grpcdatasource
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
7+
"github.com/tidwall/gjson"
8+
9+
"github.com/wundergraph/astjson"
10+
)
11+
12+
// entityIndexMap maps positions in the typed gRPC response back to positions
13+
// in the original representations array. The slice index is the response
14+
// position; the value is the representation index. It is built per call by
15+
// recording the position of every representation whose __typename matches
16+
// the requested entity type.
17+
type entityIndexMap []int
18+
19+
// newEntityIndexMap builds the index map for a single entity call by collecting
20+
// the positions of representations whose __typename matches the requested type.
21+
// A single pass over representations populates the slice.
22+
func newEntityIndexMap(requestedEntityType string, representations []gjson.Result) entityIndexMap {
23+
indexMap := make(entityIndexMap, 0, len(representations))
24+
for i, representation := range representations {
25+
if representation.Get(typenameFieldName).String() == requestedEntityType {
26+
indexMap = append(indexMap, i)
27+
}
28+
}
29+
return indexMap
30+
}
31+
32+
// getRepresentations gets the representations from the variables.
33+
// If no representations are found, it returns nil.
34+
func getRepresentations(variables gjson.Result) []gjson.Result {
35+
r := variables.Get("representations")
36+
if !r.Exists() {
37+
return nil
38+
}
39+
40+
return r.Array()
41+
}
42+
43+
// validateEntityResponse verifies that the number of entities returned by the
44+
// subgraph matches the number of representations of the requested type.
45+
// Callers should subsequently build an entityIndexMap via newEntityIndexMap to
46+
// merge the response — mergeEntities relies on the invariant that
47+
// len(response entities) == len(indexMap), which this function establishes.
48+
func validateEntityResponse(data *astjson.Value, requestedEntityType string, representations []gjson.Result) error {
49+
if data == nil {
50+
return errors.New("validateEntityResponse: subgraph response data is nil")
51+
}
52+
53+
if requestedEntityType == "" {
54+
return errors.New("validateEntityResponse: requested entity type is empty; the entity RPC plan is missing a RequestedEntityType")
55+
}
56+
57+
if len(representations) == 0 {
58+
return errors.New("validateEntityResponse: no entity representations provided in the request variables")
59+
}
60+
61+
expected := 0
62+
for _, representation := range representations {
63+
if representation.Get(typenameFieldName).String() == requestedEntityType {
64+
expected++
65+
}
66+
}
67+
68+
entities := data.Get(entityPath).GetArray()
69+
if len(entities) != expected {
70+
return fmt.Errorf("entity type %s received %d entities in the subgraph response, but %d are expected", requestedEntityType, len(entities), expected)
71+
}
72+
73+
return nil
74+
}
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
package grpcdatasource
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"github.com/tidwall/gjson"
8+
9+
"github.com/wundergraph/astjson"
10+
)
11+
12+
func TestNewEntityIndexMap(t *testing.T) {
13+
t.Run("returns empty map when no representations match", func(t *testing.T) {
14+
reps := getRepresentations(gjson.Parse(`{"representations":[
15+
{"__typename":"Storage","id":"1"}
16+
]}`))
17+
idx := newEntityIndexMap("Product", reps)
18+
assert.Equal(t, entityIndexMap{}, idx)
19+
})
20+
21+
t.Run("returns empty map when representations are nil", func(t *testing.T) {
22+
idx := newEntityIndexMap("Product", nil)
23+
assert.Equal(t, entityIndexMap{}, idx)
24+
})
25+
26+
t.Run("ordered representations [Product, Product, Storage, Storage]", func(t *testing.T) {
27+
reps := getRepresentations(gjson.Parse(`{"representations":[
28+
{"__typename":"Product","id":"1"},
29+
{"__typename":"Product","id":"2"},
30+
{"__typename":"Storage","id":"3"},
31+
{"__typename":"Storage","id":"4"}
32+
]}`))
33+
34+
productIdx := newEntityIndexMap("Product", reps)
35+
assert.Equal(t, entityIndexMap{0, 1}, productIdx)
36+
37+
storageIdx := newEntityIndexMap("Storage", reps)
38+
assert.Equal(t, entityIndexMap{2, 3}, storageIdx)
39+
})
40+
41+
t.Run("unordered representations [Product, Storage, Product, Storage]", func(t *testing.T) {
42+
reps := getRepresentations(gjson.Parse(`{"representations":[
43+
{"__typename":"Product","id":"1"},
44+
{"__typename":"Storage","id":"2"},
45+
{"__typename":"Product","id":"3"},
46+
{"__typename":"Storage","id":"4"}
47+
]}`))
48+
49+
productIdx := newEntityIndexMap("Product", reps)
50+
assert.Equal(t, entityIndexMap{0, 2}, productIdx)
51+
52+
storageIdx := newEntityIndexMap("Storage", reps)
53+
assert.Equal(t, entityIndexMap{1, 3}, storageIdx)
54+
})
55+
56+
t.Run("interleaved representations across three types", func(t *testing.T) {
57+
reps := getRepresentations(gjson.Parse(`{"representations":[
58+
{"__typename":"Product","id":"1"},
59+
{"__typename":"Storage","id":"2"},
60+
{"__typename":"Warehouse","id":"3"},
61+
{"__typename":"Product","id":"4"},
62+
{"__typename":"Warehouse","id":"5"},
63+
{"__typename":"Storage","id":"6"}
64+
]}`))
65+
66+
assert.Equal(t, entityIndexMap{0, 3}, newEntityIndexMap("Product", reps))
67+
assert.Equal(t, entityIndexMap{1, 5}, newEntityIndexMap("Storage", reps))
68+
assert.Equal(t, entityIndexMap{2, 4}, newEntityIndexMap("Warehouse", reps))
69+
})
70+
71+
t.Run("single matching representation", func(t *testing.T) {
72+
reps := getRepresentations(gjson.Parse(`{"representations":[
73+
{"__typename":"Storage","id":"1"},
74+
{"__typename":"Product","id":"2"},
75+
{"__typename":"Storage","id":"3"}
76+
]}`))
77+
78+
assert.Equal(t, entityIndexMap{1}, newEntityIndexMap("Product", reps))
79+
})
80+
81+
t.Run("preserves original positions for fully matching list", func(t *testing.T) {
82+
reps := getRepresentations(gjson.Parse(`{"representations":[
83+
{"__typename":"Product","id":"1"},
84+
{"__typename":"Product","id":"2"},
85+
{"__typename":"Product","id":"3"}
86+
]}`))
87+
88+
assert.Equal(t, entityIndexMap{0, 1, 2}, newEntityIndexMap("Product", reps))
89+
})
90+
91+
t.Run("interface entity matches by typename string only", func(t *testing.T) {
92+
// Interface-entity representations carry the interface name as __typename
93+
// (e.g. "Resource"). The index map cares only about the typename string,
94+
// not whether it refers to an interface or a concrete type.
95+
reps := getRepresentations(gjson.Parse(`{"representations":[
96+
{"__typename":"Resource","id":"1"},
97+
{"__typename":"Product","id":"2"},
98+
{"__typename":"Resource","id":"3"},
99+
{"__typename":"Storage","id":"4"},
100+
{"__typename":"Resource","id":"5"}
101+
]}`))
102+
103+
assert.Equal(t, entityIndexMap{0, 2, 4}, newEntityIndexMap("Resource", reps))
104+
// Concrete types in the same list are independent.
105+
assert.Equal(t, entityIndexMap{1}, newEntityIndexMap("Product", reps))
106+
assert.Equal(t, entityIndexMap{3}, newEntityIndexMap("Storage", reps))
107+
})
108+
}
109+
110+
func TestGetRepresentations(t *testing.T) {
111+
t.Run("returns nil when representations key missing", func(t *testing.T) {
112+
vars := gjson.Parse(`{"other":"value"}`)
113+
assert.Nil(t, getRepresentations(vars))
114+
})
115+
116+
t.Run("returns empty slice when representations is empty array", func(t *testing.T) {
117+
vars := gjson.Parse(`{"representations":[]}`)
118+
reps := getRepresentations(vars)
119+
assert.NotNil(t, reps)
120+
assert.Empty(t, reps)
121+
})
122+
123+
t.Run("returns representations when present", func(t *testing.T) {
124+
vars := gjson.Parse(`{"representations":[{"__typename":"Product","id":"1"},{"__typename":"Storage","id":"2"}]}`)
125+
reps := getRepresentations(vars)
126+
assert.Len(t, reps, 2)
127+
assert.Equal(t, "Product", reps[0].Get("__typename").String())
128+
assert.Equal(t, "Storage", reps[1].Get("__typename").String())
129+
})
130+
}
131+
func TestValidateEntityResponse(t *testing.T) {
132+
reps := getRepresentations(gjson.Parse(`{"representations":[
133+
{"__typename":"Product","id":"1"},
134+
{"__typename":"Product","id":"2"}
135+
]}`))
136+
137+
t.Run("returns error when data is nil", func(t *testing.T) {
138+
err := validateEntityResponse(nil, "Product", reps)
139+
assert.ErrorContains(t, err, "validateEntityResponse: subgraph response data is nil")
140+
})
141+
142+
t.Run("returns error when requested entity type is empty", func(t *testing.T) {
143+
data := astjson.MustParse(`{"_entities":[]}`)
144+
err := validateEntityResponse(data, "", reps)
145+
assert.ErrorContains(t, err, "validateEntityResponse: requested entity type is empty")
146+
})
147+
148+
t.Run("returns error when representations are empty", func(t *testing.T) {
149+
data := astjson.MustParse(`{"_entities":[]}`)
150+
err := validateEntityResponse(data, "Product", nil)
151+
assert.ErrorContains(t, err, "validateEntityResponse: no entity representations provided")
152+
})
153+
154+
t.Run("returns error when entity count mismatches representation count", func(t *testing.T) {
155+
data := astjson.MustParse(`{"_entities":[{"__typename":"Product","id":"1"}]}`)
156+
err := validateEntityResponse(data, "Product", reps)
157+
assert.ErrorContains(t, err, "entity type Product received 1 entities in the subgraph response, but 2 are expected")
158+
})
159+
160+
t.Run("returns nil when entity count matches representation count", func(t *testing.T) {
161+
data := astjson.MustParse(`{"_entities":[{"__typename":"Product","id":"1"},{"__typename":"Product","id":"2"}]}`)
162+
assert.NoError(t, validateEntityResponse(data, "Product", reps))
163+
})
164+
165+
t.Run("counts only representations of the requested type", func(t *testing.T) {
166+
mixedReps := getRepresentations(gjson.Parse(`{"representations":[
167+
{"__typename":"Product","id":"1"},
168+
{"__typename":"Storage","id":"2"},
169+
{"__typename":"Product","id":"3"}
170+
]}`))
171+
data := astjson.MustParse(`{"_entities":[{"__typename":"Product","id":"1"},{"__typename":"Product","id":"3"}]}`)
172+
assert.NoError(t, validateEntityResponse(data, "Product", mixedReps))
173+
})
174+
175+
t.Run("returns error when _entities key is missing", func(t *testing.T) {
176+
data := astjson.MustParse(`{}`)
177+
err := validateEntityResponse(data, "Product", reps)
178+
assert.ErrorContains(t, err, "entity type Product received 0 entities in the subgraph response, but 2 are expected")
179+
})
180+
181+
t.Run("returns error when _entities path is not an array", func(t *testing.T) {
182+
data := astjson.MustParse(`{"_entities":"not an array"}`)
183+
err := validateEntityResponse(data, "Product", reps)
184+
assert.Error(t, err)
185+
})
186+
}

v2/pkg/engine/datasource/grpc_datasource/execution_plan.go

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ type RPCCall struct {
110110
Response RPCMessage
111111
// ResponsePath is the path to the response in the JSON response
112112
ResponsePath ast.Path
113+
// RequestedEntityType is the type of the entity that is being requested
114+
// Empty if the call is not an entity lookup.
115+
RequestedEntityType string
113116
}
114117

115118
// RPCMessage represents a gRPC message structure for requests and responses.
@@ -363,6 +366,11 @@ func NewPlanner(subgraphName string, mapping *GRPCMapping, federationConfigs pla
363366

364367
// formatRPCMessage formats an RPCMessage and adds it to the string builder with the specified indentation
365368
func formatRPCMessage(sb *strings.Builder, message RPCMessage, indent int) {
369+
visited := make(map[*RPCMessage]struct{})
370+
formatRPCMessageVisited(sb, message, indent, visited)
371+
}
372+
373+
func formatRPCMessageVisited(sb *strings.Builder, message RPCMessage, indent int, visited map[*RPCMessage]struct{}) {
366374
indentStr := strings.Repeat(" ", indent)
367375

368376
fmt.Fprintf(sb, "%sName: %s\n", indentStr, message.Name)
@@ -375,25 +383,34 @@ func formatRPCMessage(sb *strings.Builder, message RPCMessage, indent int) {
375383
fmt.Fprintf(sb, "%s JSONPath: %s\n", indentStr, field.JSONPath)
376384
fmt.Fprintf(sb, "%s ResolvePath: %s\n", indentStr, field.ResolvePath.String())
377385

378-
if field.Message != nil {
379-
fmt.Fprintf(sb, "%s Message:\n", indentStr)
380-
formatRPCMessage(sb, *field.Message, indent+6)
386+
if field.Message == nil {
387+
return
388+
}
389+
390+
fmt.Fprintf(sb, "%s Message:\n", indentStr)
391+
if _, seen := visited[field.Message]; seen {
392+
fmt.Fprintf(sb, "%s <recursive: %s>\n", indentStr, field.Message.Name)
393+
continue
381394
}
395+
visited[field.Message] = struct{}{}
396+
formatRPCMessageVisited(sb, *field.Message, indent+6, visited)
382397
}
383398
}
384399

385400
type rpcPlanningContext struct {
386-
operation *ast.Document
387-
definition *ast.Document
388-
mapping *GRPCMapping
401+
operation *ast.Document
402+
definition *ast.Document
403+
mapping *GRPCMapping
404+
visitedInputTypes map[string]*RPCMessage
389405
}
390406

391407
// newRPCPlanningContext creates a new RPCPlanningContext.
392408
func newRPCPlanningContext(operation *ast.Document, definition *ast.Document, mapping *GRPCMapping) *rpcPlanningContext {
393409
return &rpcPlanningContext{
394-
operation: operation,
395-
definition: definition,
396-
mapping: mapping,
410+
operation: operation,
411+
definition: definition,
412+
mapping: mapping,
413+
visitedInputTypes: make(map[string]*RPCMessage, len(definition.InputObjectTypeDefinitions)),
397414
}
398415
}
399416

@@ -684,11 +701,26 @@ func (r *rpcPlanningContext) buildMessageFromInputObjectType(node *ast.Node) (*R
684701
return nil, fmt.Errorf("unable to build message from input object type definition - incorrect type: %s", node.Kind)
685702
}
686703

704+
typeName := node.NameString(r.definition)
705+
706+
// If we've already started building this type, return the in-progress message
707+
// pointer to break the recursion cycle. The message's fields are populated by
708+
// the caller that first entered this type, so the pointer will be complete once
709+
// the top-level call returns.
710+
if existing, ok := r.visitedInputTypes[typeName]; ok {
711+
return existing, nil
712+
}
713+
687714
inputObjectDefinition := r.definition.InputObjectTypeDefinitions[node.Ref]
688715
message := &RPCMessage{
689-
Name: node.NameString(r.definition),
716+
Name: typeName,
690717
Fields: make(RPCFields, 0, len(inputObjectDefinition.InputFieldsDefinition.Refs)),
691718
}
719+
720+
// Register the message before recursing into fields so that recursive
721+
// references resolve to this same pointer.
722+
r.visitedInputTypes[typeName] = message
723+
692724
for _, inputFieldRef := range inputObjectDefinition.InputFieldsDefinition.Refs {
693725
field, err := r.buildMessageFieldFromInputValueDefinition(inputFieldRef, node)
694726
if err != nil {

0 commit comments

Comments
 (0)