Skip to content
64 changes: 64 additions & 0 deletions v2/pkg/engine/datasource/grpc_datasource/codec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package grpcdatasource

import (
"fmt"

"google.golang.org/grpc/encoding"
"google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/mem"
)

var defaultCodec = encoding.GetCodecV2("proto")

type connectCodec struct{}

// TODO: force codec as client option instead of registering it globally.
func init() {
encoding.RegisterCodecV2(&connectCodec{})
}

// Name implements [encoding.CodecV2].
func (c *connectCodec) Name() string {
// we use the default proto codec to allow marshalling our own message but not
// interfere with the default proto codec for servers to unmarshal it.
return proto.Name
}

// Marshal implements [encoding.CodecV2].
func (c *connectCodec) Marshal(v any) (out mem.BufferSlice, err error) {
switch v := v.(type) {
case *PreWiredInputMessage:
protoBytes, err := v.wire()
if err != nil {
return nil, err
}

if mem.IsBelowBufferPoolingThreshold(v.size) {
out = append(out, mem.SliceBuffer(protoBytes))
return out, nil
} else {
pool := mem.DefaultBufferPool()
buf := pool.Get(v.size)

copy(*buf, protoBytes)

out = append(out, mem.NewBuffer(buf, pool))
return out, nil
}
}

// TODO: This should never happen
if defaultCodec == nil {
return nil, fmt.Errorf("default codec is nil")
}

return defaultCodec.Marshal(v)
}

// Unmarshal implements [encoding.CodecV2].
// TODO: Unmarshal to astjson
func (c *connectCodec) Unmarshal(data mem.BufferSlice, v any) error {
return defaultCodec.Unmarshal(data, v)
}

var _ encoding.CodecV2 = (*connectCodec)(nil)
53 changes: 9 additions & 44 deletions v2/pkg/engine/datasource/grpc_datasource/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ type EnumValue struct {
// RPCCompiler compiles protobuf schema strings into a Document and can
// build protobuf messages from JSON data based on the schema.
type RPCCompiler struct {
doc *Document // The compiled Document
Ancestor []Message
doc *Document // The compiled Document
runtime *runtimeSchema // The compiled runtime schema
}

// ServiceByName returns a Service by its name.
Expand Down Expand Up @@ -311,6 +311,13 @@ func NewProtoCompiler(schema string, mapping *GRPCMapping) (*RPCCompiler, error)
// Process the schema file
pc.processFile(schemaFile, mapping)

runtime, err := newSchemaRuntime(pc.doc)
if err != nil {
return nil, err
}

pc.runtime = runtime

return pc, nil
}

Expand Down Expand Up @@ -370,48 +377,6 @@ func (s *ServiceCall) MethodFullName() string {
return builder.String()
}

// func (p *RPCCompiler) CompileFetches(graph *DependencyGraph, fetches []FetchItem, inputData gjson.Result) ([]Invocation, error) {
// invocations := make([]Invocation, 0, len(fetches))

// resultChan := make(chan Invocation, len(fetches))
// errChan := make(chan error, len(fetches))

// wg := sync.WaitGroup{}
// wg.Add(len(fetches))

// for _, node := range fetches {
// go func() {
// defer wg.Done()
// invocation, err := p.CompileNode(graph, node, inputData)
// if err != nil {
// errChan <- err
// return
// }

// resultChan <- invocation
// node.Invocation = &invocation
// }()
// }

// close(resultChan)
// close(errChan)

// var joinErr error
// for err := range errChan {
// joinErr = errors.Join(joinErr, err)
// }

// if joinErr != nil {
// return nil, joinErr
// }

// for invocation := range resultChan {
// invocations = append(invocations, invocation)
// }

// return invocations, nil
// }

func (p *RPCCompiler) CompileFetches(graph *DependencyGraph, fetches []FetchItem, inputData gjson.Result) ([]ServiceCall, error) {
serviceCalls := make([]ServiceCall, 0, len(fetches))

Expand Down
49 changes: 39 additions & 10 deletions v2/pkg/engine/datasource/grpc_datasource/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ import (
"errors"
"fmt"

"github.com/tidwall/gjson"

"github.com/wundergraph/astjson"
"github.com/wundergraph/go-arena"
)

// entityIndexMap maps positions in the typed gRPC response back to positions
Expand All @@ -19,33 +18,63 @@ type entityIndexMap []int
// newEntityIndexMap builds the index map for a single entity call by collecting
// the positions of representations whose __typename matches the requested type.
// A single pass over representations populates the slice.
func newEntityIndexMap(requestedEntityType string, representations []gjson.Result) entityIndexMap {
func newEntityIndexMap(requestedEntityType string, representations []*astjson.Value) entityIndexMap {
indexMap := make(entityIndexMap, 0, len(representations))
for i, representation := range representations {
if representation.Get(typenameFieldName).String() == requestedEntityType {
if string(representation.Get(typenameFieldName).GetStringBytes()) == requestedEntityType {
indexMap = append(indexMap, i)
}
}
return indexMap
}

// getRepresentations gets the representations from the variables.
// If no representations are found, it returns nil.
func getRepresentations(variables gjson.Result) []gjson.Result {
// getRepresentationsAST gets the representations from the variables.
// If no representations are found, it returns an empty slice.
func getRepresentations(variables *astjson.Value) []*astjson.Value {
r := variables.Get("representations")
if !r.Exists() {
return nil
}

arr := r.GetArray()
if len(arr) == 0 {
return make([]*astjson.Value, 0)
}

return arr
}

// filterRepresentations filters the representations to only include the ones of the requested entity type.
func filterRepresentations(arena arena.Arena, variables *astjson.Value, requestedEntityType string) *astjson.Value {
r := variables.Get("representations")
if !r.Exists() {
return nil
}

return r.Array()
representations := r.GetArray()
if len(representations) == 0 {
return nil
}

ov := astjson.ObjectValue(arena)
representationsArr := astjson.ArrayValue(arena)

for _, representation := range representations {
if string(representation.Get(typenameFieldName).GetStringBytes()) == requestedEntityType {
representationsArr.SetArrayItem(arena, len(representationsArr.GetArray()), representation)
}
}

ov.Set(arena, "representations", representationsArr)
return ov
}

// validateEntityResponse verifies that the number of entities returned by the
// subgraph matches the number of representations of the requested type.
// Callers should subsequently build an entityIndexMap via newEntityIndexMap to
// merge the response — mergeEntities relies on the invariant that
// len(response entities) == len(indexMap), which this function establishes.
func validateEntityResponse(data *astjson.Value, requestedEntityType string, representations []gjson.Result) error {
func validateEntityResponse(data *astjson.Value, requestedEntityType string, representations []*astjson.Value) error {
if data == nil {
return errors.New("validateEntityResponse: subgraph response data is nil")
}
Expand All @@ -60,7 +89,7 @@ func validateEntityResponse(data *astjson.Value, requestedEntityType string, rep

expected := 0
for _, representation := range representations {
if representation.Get(typenameFieldName).String() == requestedEntityType {
if string(representation.Get(typenameFieldName).GetStringBytes()) == requestedEntityType {
expected++
}
}
Expand Down
29 changes: 14 additions & 15 deletions v2/pkg/engine/datasource/grpc_datasource/entity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"

"github.com/wundergraph/astjson"
)

func TestNewEntityIndexMap(t *testing.T) {
t.Run("returns empty map when no representations match", func(t *testing.T) {
reps := getRepresentations(gjson.Parse(`{"representations":[
reps := getRepresentations(astjson.MustParse(`{"representations":[
{"__typename":"Storage","id":"1"}
]}`))
idx := newEntityIndexMap("Product", reps)
Expand All @@ -24,7 +23,7 @@ func TestNewEntityIndexMap(t *testing.T) {
})

t.Run("ordered representations [Product, Product, Storage, Storage]", func(t *testing.T) {
reps := getRepresentations(gjson.Parse(`{"representations":[
reps := getRepresentations(astjson.MustParse(`{"representations":[
{"__typename":"Product","id":"1"},
{"__typename":"Product","id":"2"},
{"__typename":"Storage","id":"3"},
Expand All @@ -39,7 +38,7 @@ func TestNewEntityIndexMap(t *testing.T) {
})

t.Run("unordered representations [Product, Storage, Product, Storage]", func(t *testing.T) {
reps := getRepresentations(gjson.Parse(`{"representations":[
reps := getRepresentations(astjson.MustParse(`{"representations":[
{"__typename":"Product","id":"1"},
{"__typename":"Storage","id":"2"},
{"__typename":"Product","id":"3"},
Expand All @@ -54,7 +53,7 @@ func TestNewEntityIndexMap(t *testing.T) {
})

t.Run("interleaved representations across three types", func(t *testing.T) {
reps := getRepresentations(gjson.Parse(`{"representations":[
reps := getRepresentations(astjson.MustParse(`{"representations":[
{"__typename":"Product","id":"1"},
{"__typename":"Storage","id":"2"},
{"__typename":"Warehouse","id":"3"},
Expand All @@ -69,7 +68,7 @@ func TestNewEntityIndexMap(t *testing.T) {
})

t.Run("single matching representation", func(t *testing.T) {
reps := getRepresentations(gjson.Parse(`{"representations":[
reps := getRepresentations(astjson.MustParse(`{"representations":[
{"__typename":"Storage","id":"1"},
{"__typename":"Product","id":"2"},
{"__typename":"Storage","id":"3"}
Expand All @@ -79,7 +78,7 @@ func TestNewEntityIndexMap(t *testing.T) {
})

t.Run("preserves original positions for fully matching list", func(t *testing.T) {
reps := getRepresentations(gjson.Parse(`{"representations":[
reps := getRepresentations(astjson.MustParse(`{"representations":[
{"__typename":"Product","id":"1"},
{"__typename":"Product","id":"2"},
{"__typename":"Product","id":"3"}
Expand All @@ -92,7 +91,7 @@ func TestNewEntityIndexMap(t *testing.T) {
// Interface-entity representations carry the interface name as __typename
// (e.g. "Resource"). The index map cares only about the typename string,
// not whether it refers to an interface or a concrete type.
reps := getRepresentations(gjson.Parse(`{"representations":[
reps := getRepresentations(astjson.MustParse(`{"representations":[
{"__typename":"Resource","id":"1"},
{"__typename":"Product","id":"2"},
{"__typename":"Resource","id":"3"},
Expand All @@ -109,27 +108,27 @@ func TestNewEntityIndexMap(t *testing.T) {

func TestGetRepresentations(t *testing.T) {
t.Run("returns nil when representations key missing", func(t *testing.T) {
vars := gjson.Parse(`{"other":"value"}`)
vars := astjson.MustParse(`{"other":"value"}`)
assert.Nil(t, getRepresentations(vars))
})

t.Run("returns empty slice when representations is empty array", func(t *testing.T) {
vars := gjson.Parse(`{"representations":[]}`)
vars := astjson.MustParse(`{"representations":[]}`)
reps := getRepresentations(vars)
assert.NotNil(t, reps)
assert.Empty(t, reps)
})

t.Run("returns representations when present", func(t *testing.T) {
vars := gjson.Parse(`{"representations":[{"__typename":"Product","id":"1"},{"__typename":"Storage","id":"2"}]}`)
vars := astjson.MustParse(`{"representations":[{"__typename":"Product","id":"1"},{"__typename":"Storage","id":"2"}]}`)
reps := getRepresentations(vars)
assert.Len(t, reps, 2)
assert.Equal(t, "Product", reps[0].Get("__typename").String())
assert.Equal(t, "Storage", reps[1].Get("__typename").String())
assert.Equal(t, "Product", string(reps[0].Get("__typename").GetStringBytes()))
assert.Equal(t, "Storage", string(reps[1].Get("__typename").GetStringBytes()))
})
}
func TestValidateEntityResponse(t *testing.T) {
reps := getRepresentations(gjson.Parse(`{"representations":[
reps := getRepresentations(astjson.MustParse(`{"representations":[
{"__typename":"Product","id":"1"},
{"__typename":"Product","id":"2"}
]}`))
Expand Down Expand Up @@ -163,7 +162,7 @@ func TestValidateEntityResponse(t *testing.T) {
})

t.Run("counts only representations of the requested type", func(t *testing.T) {
mixedReps := getRepresentations(gjson.Parse(`{"representations":[
mixedReps := getRepresentations(astjson.MustParse(`{"representations":[
{"__typename":"Product","id":"1"},
{"__typename":"Storage","id":"2"},
{"__typename":"Product","id":"3"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package grpcdatasource
import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"

"github.com/wundergraph/graphql-go-tools/v2/pkg/astparser"
Expand Down Expand Up @@ -986,10 +985,7 @@ func TestCompositeTypeExecutionPlan(t *testing.T) {
}

require.Empty(t, tt.expectedError)
diff := cmp.Diff(tt.expectedPlan, plan)
if diff != "" {
t.Fatalf("execution plan mismatch: %s", diff)
}
assertExecutionPlanEqual(t, tt.expectedPlan, plan)
})
}
}
Expand Down Expand Up @@ -1252,10 +1248,7 @@ func TestMutationUnionExecutionPlan(t *testing.T) {
}

require.Empty(t, tt.expectedError)
diff := cmp.Diff(tt.expectedPlan, plan)
if diff != "" {
t.Fatalf("execution plan mismatch: %s", diff)
}
assertExecutionPlanEqual(t, tt.expectedPlan, plan)
})
}
}
Loading
Loading