From e07c2bddabb25ac3ce43056fa9699ec29c81a59a Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 22 Apr 2026 11:47:36 +0200 Subject: [PATCH 01/12] feat: build messages using the wire protocol directly --- .../datasource/grpc_datasource/codec.go | 53 ++ .../grpc_datasource/grpc_datasource.go | 3 + .../datasource/grpc_datasource/program.go | 193 +++++ .../grpc_datasource/program_test.go | 63 ++ .../datasource/grpc_datasource/runtime.go | 109 +++ .../grpc_datasource/runtime_test.go | 52 ++ .../engine/datasource/grpc_datasource/wire.go | 384 +++++++++ .../datasource/grpc_datasource/wire_test.go | 741 ++++++++++++++++++ 8 files changed, 1598 insertions(+) create mode 100644 v2/pkg/engine/datasource/grpc_datasource/codec.go create mode 100644 v2/pkg/engine/datasource/grpc_datasource/program.go create mode 100644 v2/pkg/engine/datasource/grpc_datasource/program_test.go create mode 100644 v2/pkg/engine/datasource/grpc_datasource/runtime.go create mode 100644 v2/pkg/engine/datasource/grpc_datasource/runtime_test.go create mode 100644 v2/pkg/engine/datasource/grpc_datasource/wire.go create mode 100644 v2/pkg/engine/datasource/grpc_datasource/wire_test.go diff --git a/v2/pkg/engine/datasource/grpc_datasource/codec.go b/v2/pkg/engine/datasource/grpc_datasource/codec.go new file mode 100644 index 000000000..f2be1dc2c --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/codec.go @@ -0,0 +1,53 @@ +package grpcdatasource + +import ( + "google.golang.org/grpc/encoding" + "google.golang.org/grpc/encoding/proto" + _ "google.golang.org/grpc/encoding/proto" + "google.golang.org/grpc/mem" +) + +var defaultCodec = encoding.GetCodecV2("proto") + +type connectCodec struct{} + +// Name implements [encoding.CodecV2]. +func (c *connectCodec) Name() string { + // we use the default proto codec to allow marshalling our own message but not + // interfere with the default proto codec for servers to unmarshal it. + return proto.Name +} + +// Marshal implements [encoding.CodecV2]. +func (c *connectCodec) Marshal(v any) (out mem.BufferSlice, err error) { + switch v := v.(type) { + case *PreWiredInputMessage: + protoBytes, err := v.wire() + if err != nil { + return nil, err + } + + if mem.IsBelowBufferPoolingThreshold(v.size) { + out = append(out, mem.SliceBuffer(protoBytes)) + return out, nil + } else { + pool := mem.DefaultBufferPool() + buf := pool.Get(v.size) + + copy(*buf, protoBytes) + + out = append(out, mem.NewBuffer(buf, pool)) + return out, nil + } + } + + return defaultCodec.Marshal(v) +} + +// Unmarshal implements [encoding.CodecV2]. +// TODO: Unmarshal to astjson +func (c *connectCodec) Unmarshal(data mem.BufferSlice, v any) error { + return defaultCodec.Unmarshal(data, v) +} + +var _ encoding.CodecV2 = (*connectCodec)(nil) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 527b3e244..8c15816d7 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -136,6 +136,9 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte representations := getRepresentations(variables) if err := graph.TopologicalSortResolve(func(nodes []FetchItem) error { + // TODO: Compile fetches should be converted to a program. + // The program defines all the fetches that need to be executed in parallel for a given query. + serviceCalls, err := d.rc.CompileFetches(graph, nodes, variables) if err != nil { return err diff --git a/v2/pkg/engine/datasource/grpc_datasource/program.go b/v2/pkg/engine/datasource/grpc_datasource/program.go new file mode 100644 index 000000000..a0441a133 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/program.go @@ -0,0 +1,193 @@ +package grpcdatasource + +import ( + "fmt" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +type program struct { + stages []stage +} + +type stage struct { + fetches []fetch +} + +type fetch struct { + id int + kind CallKind + dependentCall *RPCCall + serviceName string + methodName string + responsePath ast.Path + request *fetchRequest + response *fetchResponse +} + +type fetchRequest struct { + message *runtimeMessage + rpcMessage RPCMessage + // The wire message will be created fromt the + // request structure. + wire *wireMessage +} + +type fetchResponse struct { + // reponse type is the type of the response message. + responseType *runtimeMessage +} + +func compileProgram(plan *RPCExecutionPlan, runtime *runtimeSchema) (*program, error) { + stageIndexes, err := compileStageIndexes(plan) + if err != nil { + return nil, err + } + + // We are calculating the number of stages by finding the maximum stage index and adding 1. + stageCount := 0 + for _, stageIndex := range stageIndexes { + if stageIndex+1 > stageCount { + stageCount = stageIndex + 1 + } + } + + program := &program{ + stages: make([]stage, stageCount), + } + + stageMap := make(map[int][]fetch, stageCount) + + for i := range plan.Calls { + call := &plan.Calls[i] + + // Currently we only support one dependent call. + var dependentCall *RPCCall + if len(call.DependentCalls) > 0 { + dependentCall = &plan.Calls[call.DependentCalls[0]] + } + + fetch, err := compileFetch(call, runtime, dependentCall) + if err != nil { + return nil, err + } + + stageMap[stageIndexes[call.ID]] = append(stageMap[stageIndexes[call.ID]], fetch) + } + + for i := 0; i < stageCount; i++ { + program.stages[i] = stage{ + fetches: stageMap[i], + } + } + + return program, nil +} + +func compileFetch(call *RPCCall, runtime *runtimeSchema, dependentCall *RPCCall) (fetch, error) { + serviceName, ok := runtime.serviceNamesByMethod[call.MethodName] + if !ok { + return fetch{}, fmt.Errorf("service name not found for method %s", call.MethodName) + } + + f := fetch{ + id: call.ID, + kind: call.Kind, + dependentCall: dependentCall, + serviceName: serviceName, + methodName: call.MethodName, + responsePath: call.ResponsePath, + } + + requestMessage := runtime.getMessageByName(call.Request.Name) + if requestMessage == nil { + return fetch{}, fmt.Errorf("request message not found for method %s", call.MethodName) + } + + responseMessage := runtime.getMessageByName(call.Response.Name) + if responseMessage == nil { + return fetch{}, fmt.Errorf("response message not found for method %s", call.MethodName) + } + + f.request = &fetchRequest{ + message: requestMessage, + rpcMessage: call.Request, + } + + f.response = &fetchResponse{ + responseType: responseMessage, + } + + wireMessage, err := compileWireMessage(&f.request.rpcMessage, requestMessage) + if err != nil { + return fetch{}, err + } + + f.request.wire = wireMessage + + return f, nil +} + +func compileStageIndexes(plan *RPCExecutionPlan) ([]int, error) { + // We are using a slice to store the batch index for each noded ordered. + stageIndexes := initializeSlice(len(plan.Calls), -1) + cycleChecks := make([]bool, len(plan.Calls)) + + var visit func(index int) error + visit = func(index int) error { + if cycleChecks[index] { + return fmt.Errorf("cycle detected") + } + + // We are marking the call as visited to avoid cycles. + cycleChecks[index] = true + + call := &plan.Calls[index] + if len(call.DependentCalls) == 0 { + // If the call has no dependencies, we are setting the level to 0 and return early. + stageIndexes[index] = 0 + return nil + } + + currentLevel := 0 + // We are iterating over the dependent calls of the current call. + for _, depCallIndex := range call.DependentCalls { + if depCallIndex < 0 || depCallIndex >= len(plan.Calls) { + return fmt.Errorf("unable to find dependent call %d in execution plan", depCallIndex) + } + + // If the dependent call has already been visited, we are checking if the level of the dependent call is greater than the current level. + // If it is, we are updating the current level to the level of the dependent call. + if depLevel := stageIndexes[depCallIndex]; depLevel >= 0 { + if depLevel > currentLevel { + currentLevel = depLevel + } + continue + } + + // If the dependent call has not been visited, we are visiting it. + if err := visit(depCallIndex); err != nil { + return err + } + + // If the level of the dependent call is greater than the current level, we are updating the current level to the level of the dependent call. + if l := stageIndexes[depCallIndex]; l > currentLevel { + currentLevel = l + } + } + + // After receiving the maximum level of the dependent calls, we increment the level by 1. + stageIndexes[index] = currentLevel + 1 + return nil + } + + for callIndex := range plan.Calls { + if err := visit(callIndex); err != nil { + return nil, err + } + + clear(cycleChecks) + } + + return stageIndexes, nil +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/program_test.go b/v2/pkg/engine/datasource/grpc_datasource/program_test.go new file mode 100644 index 000000000..cfbec33dc --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/program_test.go @@ -0,0 +1,63 @@ +package grpcdatasource + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" +) + +func TestCompileProgram(t *testing.T) { + t.Parallel() + + type expected struct { + stageCount int + } + + tests := []struct { + name string + operation string + expected expected + err error + }{ + { + name: "simple program", + operation: `query UsersWithTypename { users { __typename id __typename name } }`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Parse the GraphQL schema + schemaDoc := grpctest.MustGraphQLSchema(t) + // Parse the GraphQL query + queryDoc, report := astparser.ParseGraphqlDocumentString(tt.operation) + if report.HasErrors() { + t.Fatalf("failed to parse query: %s", report.Error()) + } + + rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{ + subgraphName: "Products", + mapping: testMapping(), + }) + + plan, err := rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) + require.NoError(t, err) + + compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) + require.NoError(t, err) + + runtime, err := newSchemaRuntime(compiler) + require.NoError(t, err) + + program, err := compileProgram(plan, runtime) + require.NoError(t, err) + + fmt.Println("program", program) + }) + } +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/runtime.go b/v2/pkg/engine/datasource/grpc_datasource/runtime.go new file mode 100644 index 000000000..013a3049f --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/runtime.go @@ -0,0 +1,109 @@ +package grpcdatasource + +import ( + "fmt" + + protoref "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" +) + +type runtimeSchema struct { + messageByName map[string]*runtimeMessage + messageByFullname map[string]*runtimeMessage + serviceNamesByMethod map[string]string +} + +type runtimeMessage struct { + name string + desc protoref.MessageDescriptor + dynamicType protoref.MessageType + fieldsByName map[string]*runtimeField +} + +type runtimeField struct { + name string + owner *runtimeMessage + desc protoref.FieldDescriptor + genDesc protoref.FieldDescriptor + dataType DataType + message *runtimeMessage + repeated bool + optional bool +} + +func newSchemaRuntime(compiler *RPCCompiler) (*runtimeSchema, error) { + runtime := &runtimeSchema{ + messageByName: make(map[string]*runtimeMessage, len(compiler.doc.Messages)), + messageByFullname: make(map[string]*runtimeMessage, len(compiler.doc.Messages)), + serviceNamesByMethod: make(map[string]string, len(compiler.doc.Methods)), + } + + for i := range compiler.doc.Messages { + message := &compiler.doc.Messages[i] + + rtMessage := &runtimeMessage{ + name: message.Name, + desc: message.Desc, + dynamicType: dynamicpb.NewMessageType(message.Desc), + fieldsByName: make(map[string]*runtimeField, message.Desc.Fields().Len()), + } + + runtime.messageByName[message.Name] = rtMessage + runtime.messageByFullname[string(message.Desc.FullName())] = rtMessage + } + + for _, message := range runtime.messageByName { + if err := appendMessageFields(runtime, message); err != nil { + return nil, err + } + } + + for _, service := range compiler.doc.Services { + for i := range service.MethodsRefs { + runtime.serviceNamesByMethod[compiler.doc.Methods[i].Name] = service.FullName + } + } + + return runtime, nil +} + +func appendMessageFields(runtime *runtimeSchema, message *runtimeMessage) error { + for i := 0; i < message.desc.Fields().Len(); i++ { + fieldDesc := message.desc.Fields().Get(i) + + field := &runtimeField{ + owner: message, + name: string(fieldDesc.Name()), + desc: fieldDesc, + dataType: parseDataType(fieldDesc.Kind()), + repeated: fieldDesc.IsList(), + optional: fieldDesc.Cardinality() == protoref.Optional, + } + + if field.dataType == DataTypeMessage { + child, found := runtime.messageByFullname[string(fieldDesc.Message().FullName())] + if !found { + return fmt.Errorf("message %s not found in document", string(fieldDesc.Message().FullName())) + } + + field.message = child + } + + message.fieldsByName[string(fieldDesc.Name())] = field + } + + return nil +} + +func (r *runtimeSchema) getMessageByName(name string) *runtimeMessage { + message, found := r.messageByName[name] + if !found { + return nil + } + + return message +} + +func (m *runtimeMessage) newEmptyMessage() protoref.Message { + return m.dynamicType.New() +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/runtime_test.go b/v2/pkg/engine/datasource/grpc_datasource/runtime_test.go new file mode 100644 index 000000000..a4b224e47 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/runtime_test.go @@ -0,0 +1,52 @@ +package grpcdatasource + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewSchemaRuntime(t *testing.T) { + compiler, err := NewProtoCompiler(testSchemaWithLookup, testMapping()) + require.NoError(t, err) + + runtime, err := newSchemaRuntime(compiler) + require.NoError(t, err) + + require.Equal(t, 5, len(runtime.messageByName)) + require.Equal(t, 5, len(runtime.messageByFullname)) +} + +// =============== Test Schemas ================== // + +var testSchemaWithLookup = ` +syntax = "proto3"; +package product.v1; + +service ProductService { + rpc LookupProductById(LookupProductByIdRequest) returns (LookupProductByIdResponse) {} +} + +message LookupProductByIdRequest { + repeated LookupProductByIdInput inputs = 1; +} + +message LookupProductByIdInput { + string id = 1; +} + +message LookupProductByIdResponse { + repeated LookupProductByIdResult results = 1; +} + +message LookupProductByIdResult { + Product product = 1; +} + +message Product { + string id = 1; + string name = 2; + double price = 3; +} + +` diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire.go b/v2/pkg/engine/datasource/grpc_datasource/wire.go new file mode 100644 index 000000000..8e6208512 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/wire.go @@ -0,0 +1,384 @@ +package grpcdatasource + +import ( + "bytes" + "fmt" + "math" + "sync" + + "github.com/wundergraph/astjson" + "google.golang.org/protobuf/encoding/protowire" +) + +type PreWiredInputMessage struct { + size int + buffer []byte +} + +func NewPreWiredInputMessage(buffer []byte) *PreWiredInputMessage { + return &PreWiredInputMessage{ + size: len(buffer), + buffer: buffer, + } +} + +func (c *PreWiredInputMessage) wire() ([]byte, error) { + if c.buffer == nil { + return nil, fmt.Errorf("connect message not initialized") + } + + return c.buffer, nil +} + +type wireMessage struct { + fields []wireField + oneOfType OneOfType +} + +type wireField struct { + tag []byte + number protowire.Number + dataType DataType + wireType protowire.Type + runtimeMessage *runtimeMessage + staticValue string + jsonPath string + optional bool + repeated bool + listMetadata *ListMetadata + child *wireMessage +} + +const ( + minBufferSize = 1 << 10 // 1KiB + maxBufferSize = 1 << 20 // 1MiB + + defaultBufferCount = 10 +) + +// TODO: This should use an area instead +type bufferPool struct { + pool sync.Pool + defaultSize int +} + +func newMempool(size int) *bufferPool { + if size <= 0 || size < minBufferSize || size > maxBufferSize { + size = minBufferSize + } + + return &bufferPool{ + pool: sync.Pool{ + New: func() any { + return bytes.NewBuffer(make([]byte, 0, size)) + }, + }, + defaultSize: size, + } +} + +func (b *bufferPool) Get() *bytes.Buffer { + buf := b.pool.Get().(*bytes.Buffer) + buf.Reset() + if buf.Cap() < b.defaultSize { + buf.Grow(b.defaultSize - buf.Cap()) + } + + return buf +} + +func (b *bufferPool) Put(buf *bytes.Buffer) { + b.pool.Put(buf) +} + +type wireBuilder struct { + pool *bufferPool + buffers []*bytes.Buffer +} + +func newWireBuilder(size int) *wireBuilder { + return &wireBuilder{ + pool: newMempool(size), + buffers: make([]*bytes.Buffer, 0, defaultBufferCount), + } +} + +func (w *wireBuilder) Reset() { + for _, buf := range w.buffers { + w.pool.Put(buf) + } + w.buffers = w.buffers[:0] +} + +func compileWireMessage(rpcMessage *RPCMessage, message *runtimeMessage) (*wireMessage, error) { + if message == nil { + return nil, fmt.Errorf("message not found for fetch request") + } + msg := &wireMessage{ + fields: make([]wireField, len(rpcMessage.Fields)), + } + + // TODO: This is possible for `@requires` fields, but not yet supported. + if rpcMessage.OneOfType != OneOfTypeNone { + return nil, fmt.Errorf("oneof type not supported yet") + } + + for i := range rpcMessage.Fields { + rpcField := &rpcMessage.Fields[i] + + field, ok := message.fieldsByName[rpcField.Name] + if !ok { + return nil, fmt.Errorf("field not found for name %s", rpcField.Name) + } + + wf := wireField{ + number: field.desc.Number(), + runtimeMessage: field.message, + dataType: rpcField.ProtoTypeName, + wireType: getWireType(field.dataType), + jsonPath: rpcField.JSONPath, + staticValue: rpcField.StaticValue, + optional: rpcField.Optional, + repeated: rpcField.Repeated, + listMetadata: rpcField.ListMetadata, + } + + if rpcField.Message != nil { + child, err := compileWireMessage(rpcField.Message, field.message) + if err != nil { + return nil, err + } + + wf.child = child + } + + wf.tag = protowire.AppendTag(nil, wf.number, wf.wireType) + msg.fields[i] = wf + } + + return msg, nil +} + +// createProtoWire creates a proto wire from the wire plan. +func (w *wireMessage) createProtoWire(builder *wireBuilder, data *astjson.Value) ([]byte, error) { + // TODO: Use arena or a global buffer pool + buf := builder.pool.Get() + + for _, field := range w.fields { + err := field.appendFieldWire(builder, buf, data) + if err != nil { + return nil, err + } + } + + return buf.Bytes(), nil +} + +func (f *wireField) appendFieldWire(builder *wireBuilder, buf *bytes.Buffer, data *astjson.Value) error { + var fieldData *astjson.Value + + if f.jsonPath == "" { + fieldData = data + } else { + fieldData = data.Get(f.jsonPath) + } + + if !fieldData.Exists() { + if f.optional { + return nil + } + + return fmt.Errorf("field %s is required but has no value", f.jsonPath) + } + + if f.repeated { + for _, element := range fieldData.GetArray() { + err := f.appendFieldValue(builder, buf, element) + if err != nil { + return err + } + } + + return nil + } + + if f.isListWrapper() { + // TODO: build a wireMessage for the list wrapper and just create the proto wire for it + //wm := &wireMessage{fields: make([]wireField, 0, 1)} + return f.appendListFieldValue(builder, buf, fieldData, 0) + } + + if f.isOptionalScalar() { + return f.appendOptionalScalarFieldValue(builder, buf, fieldData) + } + + return f.appendFieldValue(builder, buf, fieldData) +} + +// appendListFieldValue appends the list value to the buffer. +// This is used for lists and nested lists which are defined as wrapper messages. +// +// Example: +// ```proto +// +// message ListOfString { +// message List { +// repeated string items = 1; +// } +// List list = 1; +// } +// +// ``` +func (f *wireField) appendListFieldValue(builder *wireBuilder, buf *bytes.Buffer, data *astjson.Value, level int) error { + if level >= f.listMetadata.NestingLevel { + f.listMetadata = nil // reset the list metadata to avoid infinite recursion + return f.appendFieldWire(builder, buf, data) + } + + // TODO: Check for optional + md := f.listMetadata.LevelInfo[level] + level++ + + runtimeMsg := f.runtimeMessage + if runtimeMsg == nil { + return fmt.Errorf("runtime message not found for field %s", f.jsonPath) + } + + listBuffer := builder.pool.Get() + defer builder.pool.Put(listBuffer) + + field, ok := runtimeMsg.fieldsByName["list"] + if !ok { + return fmt.Errorf("list field not found for message %s but was expected", runtimeMsg.name) + } + + // We will always have a message type here, therefore we must use the bytes type. + listBuffer.Write(protowire.AppendTag(nil, field.desc.Number(), protowire.BytesType)) + + listMessage := field.message + if listMessage == nil { + return fmt.Errorf("expected nested message type for list wrapper field but the field %s doesn't have a message", f.jsonPath) + } + + itemsField, ok := listMessage.fieldsByName["items"] + if !ok { + return fmt.Errorf("items field not found for message %s but was expected", listMessage.name) + } + + elements := data.GetArray() + if len(elements) == 0 && !md.Optional { + return fmt.Errorf("list is required but has no elements") + } + + itemsBuffer := builder.pool.Get() + defer builder.pool.Put(itemsBuffer) + + for i := range elements { + iwf := wireField{ + number: itemsField.desc.Number(), + dataType: f.dataType, + wireType: getWireType(itemsField.dataType), + runtimeMessage: itemsField.message, + listMetadata: f.listMetadata, + } + + iwf.tag = protowire.AppendTag(nil, iwf.number, iwf.wireType) + if err := iwf.appendListFieldValue(builder, itemsBuffer, elements[i], level); err != nil { + return err + } + } + + listBuffer.Write(protowire.AppendBytes(nil, itemsBuffer.Bytes())) + + buf.Write(f.tag) + buf.Write(protowire.AppendBytes(nil, listBuffer.Bytes())) + + return nil +} + +func (f *wireField) appendOptionalScalarFieldValue(builder *wireBuilder, buf *bytes.Buffer, data *astjson.Value) error { + if f.runtimeMessage == nil { + return fmt.Errorf("runtime message not found for optional scalar field %s but was expected", f.jsonPath) + } + + wrapperField, ok := f.runtimeMessage.fieldsByName[knownTypeOptionalFieldValueName] + if !ok { + return fmt.Errorf("wrapper field not found for message %s but was expected", f.runtimeMessage.name) + } + + fieldBuf := builder.pool.Get() + defer builder.pool.Put(fieldBuf) + + wf := wireField{ + number: wrapperField.desc.Number(), + dataType: wrapperField.dataType, + wireType: getWireType(wrapperField.dataType), + jsonPath: f.jsonPath, + runtimeMessage: wrapperField.message, + } + + wf.tag = protowire.AppendTag(nil, wf.number, wf.wireType) + err := wf.appendFieldValue(builder, fieldBuf, data) + if err != nil { + return err + } + + buf.Write(f.tag) + buf.Write(protowire.AppendBytes(nil, fieldBuf.Bytes())) + return nil +} + +func (f *wireField) isListWrapper() bool { + return f.listMetadata != nil +} + +func (f *wireField) isOptionalScalar() bool { + return f.optional && f.dataType != DataTypeMessage +} + +func (f *wireField) appendFieldValue(builder *wireBuilder, buf *bytes.Buffer, data *astjson.Value) error { + if f.child != nil { + childWire, err := f.child.createProtoWire(builder, data) + if err != nil { + return err + + } + + buf.Write(f.tag) + buf.Write(protowire.AppendBytes(nil, childWire)) + return nil + } + + switch f.wireType { + case protowire.BytesType: + value := data.GetStringBytes() + buf.Write(f.tag) + buf.Write(protowire.AppendBytes(nil, value)) + case protowire.VarintType: + uintValue := data.GetUint64() + buf.Write(f.tag) + buf.Write(protowire.AppendVarint(nil, uintValue)) + case protowire.Fixed64Type: + buf.Write(f.tag) + buf.Write(protowire.AppendFixed64(nil, math.Float64bits(data.GetFloat64()))) + default: + return fmt.Errorf("unsupported wire type %d", f.wireType) + } + + return nil +} + +func getWireType(dataType DataType) protowire.Type { + switch dataType { + case DataTypeString, DataTypeBytes: + return protowire.BytesType + case DataTypeInt32, DataTypeInt64, DataTypeUint32, DataTypeUint64: + return protowire.VarintType + case DataTypeFloat, DataTypeDouble: + return protowire.Fixed64Type + case DataTypeMessage: + return protowire.BytesType + default: + return protowire.VarintType + } +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go new file mode 100644 index 000000000..972d476f4 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go @@ -0,0 +1,741 @@ +package grpcdatasource + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/astjson" + "google.golang.org/protobuf/proto" + protoref "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" +) + +var testWireSchema = ` +syntax = "proto3"; +package test.wire.v1; + +import "google/protobuf/wrappers.proto"; + +enum Status { + STATUS_UNSPECIFIED = 0; + STATUS_ACTIVE = 1; + STATUS_INACTIVE = 2; +} + +message EmptyRequest {} + +message ScalarRequest { + string name = 1; + int32 age = 2; + double score = 3; + bool active = 4; +} + +message WrapperScalarRequest { + google.protobuf.StringValue name = 1; + google.protobuf.Int32Value age = 2; + google.protobuf.DoubleValue score = 3; + google.protobuf.BoolValue active = 4; +} + +message RepeatedScalarRequest { + repeated string tags = 1; + repeated int32 scores = 2; +} + +message NestedItem { + string id = 1; + string value = 2; +} + +message NestedMessageRequest { + NestedItem item = 1; + repeated NestedItem items = 2; +} + +message ListOfString { + message List { + repeated string items = 1; + } + List list = 1; +} + +message ListOfNestedItem { + message List { + repeated NestedItem items = 1; + } + List list = 1; +} + +message ListOfListOfString { + message List { + repeated ListOfString items = 1; + } + List list = 1; +} + +message ListOfListOfNestedItem { + message List { + repeated ListOfNestedItem items = 1; + } + List list = 1; +} + +message ListWrapperRequest { + ListOfString optional_tags = 1; + ListOfNestedItem optional_items = 2; +} + +message NestedListRequest { + ListOfListOfString tag_groups = 1; + ListOfListOfNestedItem item_groups = 2; +} + +message EnumRequest { + Status status = 1; + repeated Status statuses = 2; +} + +message MixedRequest { + string id = 1; + google.protobuf.StringValue description = 2; + repeated string tags = 3; + ListOfString keywords = 4; + NestedItem metadata = 5; + double price = 6; + Status status = 7; +} + +service TestService { + rpc Empty(EmptyRequest) returns (EmptyRequest) {} + rpc Scalar(ScalarRequest) returns (ScalarRequest) {} + rpc WrapperScalar(WrapperScalarRequest) returns (WrapperScalarRequest) {} + rpc RepeatedScalar(RepeatedScalarRequest) returns (RepeatedScalarRequest) {} + rpc NestedMessage(NestedMessageRequest) returns (NestedMessageRequest) {} + rpc ListWrapper(ListWrapperRequest) returns (ListWrapperRequest) {} + rpc NestedList(NestedListRequest) returns (NestedListRequest) {} + rpc Enum(EnumRequest) returns (EnumRequest) {} + rpc Mixed(MixedRequest) returns (MixedRequest) {} +} +` + +func newWireTestRuntime(t *testing.T) *runtimeSchema { + t.Helper() + compiler, err := NewProtoCompiler(testWireSchema, &GRPCMapping{ + Service: "TestService", + }) + require.NoError(t, err) + runtime, err := newSchemaRuntime(compiler) + require.NoError(t, err) + return runtime +} + +func compileTestWireMessage(t *testing.T, runtime *runtimeSchema, messageName string, rpcMessage *RPCMessage) *wireMessage { + t.Helper() + msg := runtime.getMessageByName(messageName) + require.NotNilf(t, msg, "message %q not found in runtime", messageName) + wm, err := compileWireMessage(rpcMessage, msg) + require.NoError(t, err) + return wm +} + +// marshalDynamic builds a dynamicpb message using the runtime's descriptor and marshals it via proto.Marshal. +// This produces the canonical protobuf encoding to compare against createProtoWire output. +func marshalDynamic(t *testing.T, runtime *runtimeSchema, messageName string, build func(msg *dynamicpb.Message, desc protoref.MessageDescriptor)) []byte { + t.Helper() + rtMsg := runtime.getMessageByName(messageName) + require.NotNilf(t, rtMsg, "message %q not found in runtime", messageName) + msg := dynamicpb.NewMessage(rtMsg.desc) + build(msg, rtMsg.desc) + out, err := proto.Marshal(msg) + require.NoError(t, err) + return out +} + +// assertProtoEqual unmarshals both byte slices into the same message type and compares the resulting messages. +// This allows valid encoding differences (e.g. packed vs unpacked repeated scalars) to pass. +func assertProtoEqual(t *testing.T, runtime *runtimeSchema, messageName string, expected, got []byte) { + t.Helper() + rtMsg := runtime.getMessageByName(messageName) + require.NotNilf(t, rtMsg, "message %q not found in runtime", messageName) + + expectedMsg := dynamicpb.NewMessage(rtMsg.desc) + require.NoError(t, proto.Unmarshal(expected, expectedMsg), "failed to unmarshal expected bytes") + + gotMsg := dynamicpb.NewMessage(rtMsg.desc) + require.NoError(t, proto.Unmarshal(got, gotMsg), "failed to unmarshal got bytes") + + assert.True(t, proto.Equal(expectedMsg, gotMsg), + "messages not equal\nexpected: %v\ngot: %v\nexpected bytes: %x\ngot bytes: %x", + expectedMsg, gotMsg, expected, got) +} + +// setWrapperValue sets a google.protobuf wrapper field (e.g. StringValue, Int32Value) on a dynamic message. +func setWrapperValue(msg *dynamicpb.Message, fieldName protoref.Name, value protoref.Value) { + fd := msg.Descriptor().Fields().ByName(fieldName) + wrapper := dynamicpb.NewMessage(fd.Message()) + wrapper.Set(fd.Message().Fields().ByName("value"), value) + msg.Set(fd, protoref.ValueOfMessage(wrapper)) +} + +func TestCompileWireMessage(t *testing.T) { + runtime := newWireTestRuntime(t) + + t.Run("empty message with no fields", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "EmptyRequest", &RPCMessage{ + Name: "EmptyRequest", + Fields: nil, + }) + assert.Len(t, wm.fields, 0) + }) + + t.Run("scalar fields", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "ScalarRequest", &RPCMessage{ + Name: "ScalarRequest", + Fields: RPCFields{ + {Name: "name", ProtoTypeName: DataTypeString, JSONPath: "name"}, + {Name: "age", ProtoTypeName: DataTypeInt32, JSONPath: "age"}, + {Name: "score", ProtoTypeName: DataTypeDouble, JSONPath: "score"}, + {Name: "active", ProtoTypeName: DataTypeBool, JSONPath: "active"}, + }, + }) + assert.Len(t, wm.fields, 4) + assert.Equal(t, DataTypeString, wm.fields[0].dataType) + assert.Equal(t, DataTypeInt32, wm.fields[1].dataType) + assert.Equal(t, DataTypeDouble, wm.fields[2].dataType) + assert.Equal(t, DataTypeBool, wm.fields[3].dataType) + }) + + t.Run("wrapper scalar fields as optional scalars", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "WrapperScalarRequest", &RPCMessage{ + Name: "WrapperScalarRequest", + Fields: RPCFields{ + {Name: "name", ProtoTypeName: DataTypeString, JSONPath: "name", Optional: true}, + {Name: "age", ProtoTypeName: DataTypeInt32, JSONPath: "age", Optional: true}, + {Name: "score", ProtoTypeName: DataTypeDouble, JSONPath: "score", Optional: true}, + {Name: "active", ProtoTypeName: DataTypeBool, JSONPath: "active", Optional: true}, + }, + }) + assert.Len(t, wm.fields, 4) + assert.True(t, wm.fields[0].optional) + assert.True(t, wm.fields[1].optional) + assert.True(t, wm.fields[2].optional) + assert.True(t, wm.fields[3].optional) + }) + + t.Run("repeated scalar fields", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "RepeatedScalarRequest", &RPCMessage{ + Name: "RepeatedScalarRequest", + Fields: RPCFields{ + {Name: "tags", ProtoTypeName: DataTypeString, JSONPath: "tags", Repeated: true}, + {Name: "scores", ProtoTypeName: DataTypeInt32, JSONPath: "scores", Repeated: true}, + }, + }) + assert.Len(t, wm.fields, 2) + assert.True(t, wm.fields[0].repeated) + assert.True(t, wm.fields[1].repeated) + }) + + t.Run("list wrapper with list metadata", func(t *testing.T) { + msg := runtime.getMessageByName("ListWrapperRequest") + require.NotNil(t, msg) + // Optional + IsListType: compileWireMessage must not treat this as a wrapper scalar. + // Currently this errors because it tries to wrap in google.protobuf.*Value and looks for "value" in ListOfString. + wm, err := compileWireMessage(&RPCMessage{ + Name: "ListWrapperRequest", + Fields: RPCFields{ + { + Name: "optional_tags", + ProtoTypeName: DataTypeString, + JSONPath: "optionalTags", + Optional: true, + IsListType: true, + ListMetadata: &ListMetadata{ + NestingLevel: 1, + LevelInfo: []LevelInfo{{Optional: true}}, + }, + }, + }, + }, msg) + + require.NoError(t, err) + + assert.Len(t, wm.fields, 1) + assert.NotNil(t, wm.fields[0].listMetadata) + assert.Equal(t, 1, wm.fields[0].listMetadata.NestingLevel) + }) + + t.Run("nested list wrapper with list metadata", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "NestedListRequest", &RPCMessage{ + Name: "NestedListRequest", + Fields: RPCFields{ + { + Name: "tag_groups", + ProtoTypeName: DataTypeString, + JSONPath: "tagGroups", + IsListType: true, + ListMetadata: &ListMetadata{ + NestingLevel: 2, + LevelInfo: []LevelInfo{{Optional: false}, {Optional: false}}, + }, + }, + }, + }) + assert.Len(t, wm.fields, 1) + assert.NotNil(t, wm.fields[0].listMetadata) + assert.Equal(t, 2, wm.fields[0].listMetadata.NestingLevel) + }) + + t.Run("enum field", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "EnumRequest", &RPCMessage{ + Name: "EnumRequest", + Fields: RPCFields{ + {Name: "status", ProtoTypeName: DataTypeEnum, JSONPath: "status", EnumName: "Status"}, + {Name: "statuses", ProtoTypeName: DataTypeEnum, JSONPath: "statuses", EnumName: "Status", Repeated: true}, + }, + }) + assert.Len(t, wm.fields, 2) + assert.Equal(t, DataTypeEnum, wm.fields[0].dataType) + assert.Equal(t, DataTypeEnum, wm.fields[1].dataType) + assert.True(t, wm.fields[1].repeated) + }) +} + +func TestCreateProtoWire(t *testing.T) { + runtime := newWireTestRuntime(t) + + t.Run("empty message", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "EmptyRequest", &RPCMessage{ + Name: "EmptyRequest", + Fields: nil, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "EmptyRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) {}) + + assertProtoEqual(t, runtime, "EmptyRequest", expected, got) + }) + + t.Run("single string field", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "ScalarRequest", &RPCMessage{ + Name: "ScalarRequest", + Fields: RPCFields{ + {Name: "name", ProtoTypeName: DataTypeString, JSONPath: "name"}, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"name":"hello"}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "ScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + msg.Set(desc.Fields().ByName("name"), protoref.ValueOfString("hello")) + }) + + assertProtoEqual(t, runtime, "ScalarRequest", expected, got) + }) + + t.Run("string int32 and double fields", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "ScalarRequest", &RPCMessage{ + Name: "ScalarRequest", + Fields: RPCFields{ + {Name: "name", ProtoTypeName: DataTypeString, JSONPath: "name"}, + {Name: "age", ProtoTypeName: DataTypeInt32, JSONPath: "age"}, + {Name: "score", ProtoTypeName: DataTypeDouble, JSONPath: "score"}, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"name":"alice","age":30,"score":99.5}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "ScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + msg.Set(desc.Fields().ByName("name"), protoref.ValueOfString("alice")) + msg.Set(desc.Fields().ByName("age"), protoref.ValueOfInt32(30)) + msg.Set(desc.Fields().ByName("score"), protoref.ValueOfFloat64(99.5)) + }) + + assertProtoEqual(t, runtime, "ScalarRequest", expected, got) + }) + + t.Run("wrapper string value present", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "WrapperScalarRequest", &RPCMessage{ + Name: "WrapperScalarRequest", + Fields: RPCFields{ + { + Name: "name", ProtoTypeName: DataTypeString, JSONPath: "name", Optional: true, + }, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"name":"hello"}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "WrapperScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + setWrapperValue(msg, "name", protoref.ValueOfString("hello")) + }) + + assertProtoEqual(t, runtime, "WrapperScalarRequest", expected, got) + }) + + t.Run("wrapper string value absent", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "WrapperScalarRequest", &RPCMessage{ + Name: "WrapperScalarRequest", + Fields: RPCFields{ + {Name: "name", ProtoTypeName: DataTypeString, JSONPath: "name", Optional: true}, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "WrapperScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + // name not set — wrapper absent means null + }) + + assertProtoEqual(t, runtime, "WrapperScalarRequest", expected, got) + }) + + t.Run("wrapper int32 and double values", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "WrapperScalarRequest", &RPCMessage{ + Name: "WrapperScalarRequest", + Fields: RPCFields{ + {Name: "age", ProtoTypeName: DataTypeInt32, JSONPath: "age", Optional: true}, + {Name: "score", ProtoTypeName: DataTypeDouble, JSONPath: "score", Optional: true}, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"age":25,"score":3.14}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "WrapperScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + setWrapperValue(msg, "age", protoref.ValueOfInt32(25)) + setWrapperValue(msg, "score", protoref.ValueOfFloat64(3.14)) + }) + + assertProtoEqual(t, runtime, "WrapperScalarRequest", expected, got) + }) + + t.Run("repeated strings", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "RepeatedScalarRequest", &RPCMessage{ + Name: "RepeatedScalarRequest", + Fields: RPCFields{ + {Name: "tags", ProtoTypeName: DataTypeString, JSONPath: "tags", Repeated: true}, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"tags":["foo","bar","baz"]}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "RepeatedScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + list := msg.Mutable(desc.Fields().ByName("tags")).List() + list.Append(protoref.ValueOfString("foo")) + list.Append(protoref.ValueOfString("bar")) + list.Append(protoref.ValueOfString("baz")) + }) + + assertProtoEqual(t, runtime, "RepeatedScalarRequest", expected, got) + }) + + t.Run("repeated int32s", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "RepeatedScalarRequest", &RPCMessage{ + Name: "RepeatedScalarRequest", + Fields: RPCFields{ + {Name: "scores", ProtoTypeName: DataTypeInt32, JSONPath: "scores", Repeated: true}, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"scores":[1,2,3]}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "RepeatedScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + list := msg.Mutable(desc.Fields().ByName("scores")).List() + list.Append(protoref.ValueOfInt32(1)) + list.Append(protoref.ValueOfInt32(2)) + list.Append(protoref.ValueOfInt32(3)) + }) + + assertProtoEqual(t, runtime, "RepeatedScalarRequest", expected, got) + }) + + t.Run("single nested message", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "NestedMessageRequest", &RPCMessage{ + Name: "NestedMessageRequest", + Fields: RPCFields{ + { + Name: "item", ProtoTypeName: DataTypeMessage, JSONPath: "item", + Message: &RPCMessage{ + Name: "NestedItem", + Fields: RPCFields{ + {Name: "id", ProtoTypeName: DataTypeString, JSONPath: "id"}, + {Name: "value", ProtoTypeName: DataTypeString, JSONPath: "value"}, + }, + }, + }, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"item":{"id":"1","value":"a"}}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "NestedMessageRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + itemField := desc.Fields().ByName("item") + item := dynamicpb.NewMessage(itemField.Message()) + item.Set(itemField.Message().Fields().ByName("id"), protoref.ValueOfString("1")) + item.Set(itemField.Message().Fields().ByName("value"), protoref.ValueOfString("a")) + msg.Set(itemField, protoref.ValueOfMessage(item)) + }) + + assertProtoEqual(t, runtime, "NestedMessageRequest", expected, got) + }) + + t.Run("repeated nested messages", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "NestedMessageRequest", &RPCMessage{ + Name: "NestedMessageRequest", + Fields: RPCFields{ + { + Name: "items", ProtoTypeName: DataTypeMessage, JSONPath: "items", Repeated: true, + Message: &RPCMessage{ + Name: "NestedItem", + Fields: RPCFields{ + {Name: "id", ProtoTypeName: DataTypeString, JSONPath: "id"}, + {Name: "value", ProtoTypeName: DataTypeString, JSONPath: "value"}, + }, + }, + }, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"items":[{"id":"1","value":"a"},{"id":"2","value":"b"}]}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "NestedMessageRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + itemsField := desc.Fields().ByName("items") + itemDesc := itemsField.Message() + list := msg.Mutable(itemsField).List() + + item1 := dynamicpb.NewMessage(itemDesc) + item1.Set(itemDesc.Fields().ByName("id"), protoref.ValueOfString("1")) + item1.Set(itemDesc.Fields().ByName("value"), protoref.ValueOfString("a")) + list.Append(protoref.ValueOfMessage(item1)) + + item2 := dynamicpb.NewMessage(itemDesc) + item2.Set(itemDesc.Fields().ByName("id"), protoref.ValueOfString("2")) + item2.Set(itemDesc.Fields().ByName("value"), protoref.ValueOfString("b")) + list.Append(protoref.ValueOfMessage(item2)) + }) + + assertProtoEqual(t, runtime, "NestedMessageRequest", expected, got) + }) + + t.Run("enum field", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "EnumRequest", &RPCMessage{ + Name: "EnumRequest", + Fields: RPCFields{ + {Name: "status", ProtoTypeName: DataTypeEnum, JSONPath: "status", EnumName: "Status"}, + }, + }) + + // STATUS_ACTIVE = 1 + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"status":1}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "EnumRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + msg.Set(desc.Fields().ByName("status"), protoref.ValueOfEnum(1)) + }) + + assertProtoEqual(t, runtime, "EnumRequest", expected, got) + }) + + t.Run("repeated enums", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "EnumRequest", &RPCMessage{ + Name: "EnumRequest", + Fields: RPCFields{ + {Name: "statuses", ProtoTypeName: DataTypeEnum, JSONPath: "statuses", EnumName: "Status", Repeated: true}, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"statuses":[0,1,2]}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "EnumRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + list := msg.Mutable(desc.Fields().ByName("statuses")).List() + list.Append(protoref.ValueOfEnum(0)) + list.Append(protoref.ValueOfEnum(1)) + list.Append(protoref.ValueOfEnum(2)) + }) + + assertProtoEqual(t, runtime, "EnumRequest", expected, got) + }) + + t.Run("list wrapper with strings", func(t *testing.T) { + // RPC plan models ListOfString as a flat optional scalar with IsListType + ListMetadata. + // createProtoWire must produce: ListWrapperRequest { optional_tags: ListOfString { list: List { items: [...] } } } + msg := runtime.getMessageByName("ListWrapperRequest") + require.NotNil(t, msg) + wm, compileErr := compileWireMessage(&RPCMessage{ + Name: "ListWrapperRequest", + Fields: RPCFields{ + { + Name: "optional_tags", + ProtoTypeName: DataTypeString, + JSONPath: "optionalTags", + Optional: true, + IsListType: true, + ListMetadata: &ListMetadata{ + NestingLevel: 1, + LevelInfo: []LevelInfo{{Optional: true}}, + }, + }, + }, + }, msg) + + require.NoError(t, compileErr) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"optionalTags":["a","b"]}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "ListWrapperRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + // Build: ListWrapperRequest { optional_tags: ListOfString { list: List { items: ["a","b"] } } } + optTagsField := desc.Fields().ByName("optional_tags") + listOfStringDesc := optTagsField.Message() + + listField := listOfStringDesc.Fields().ByName("list") + listDesc := listField.Message() + + innerList := dynamicpb.NewMessage(listDesc) + items := innerList.Mutable(listDesc.Fields().ByName("items")).List() + items.Append(protoref.ValueOfString("a")) + items.Append(protoref.ValueOfString("b")) + + listOfString := dynamicpb.NewMessage(listOfStringDesc) + listOfString.Set(listField, protoref.ValueOfMessage(innerList)) + + msg.Set(optTagsField, protoref.ValueOfMessage(listOfString)) + }) + + assertProtoEqual(t, runtime, "ListWrapperRequest", expected, got) + }) + + t.Run("nested list wrapper two levels", func(t *testing.T) { + // RPC plan models ListOfListOfString as a flat scalar with IsListType + ListMetadata (NestingLevel=2). + // createProtoWire must produce: NestedListRequest { tag_groups: ListOfListOfString { list: { items: [ ListOfString{...}, ... ] } } } + msg := runtime.getMessageByName("NestedListRequest") + require.NotNil(t, msg) + wm, compileErr := compileWireMessage(&RPCMessage{ + Name: "NestedListRequest", + Fields: RPCFields{ + { + Name: "tag_groups", + ProtoTypeName: DataTypeString, + JSONPath: "tagGroups", + IsListType: true, + ListMetadata: &ListMetadata{ + NestingLevel: 2, + LevelInfo: []LevelInfo{{Optional: false}, {Optional: false}}, + }, + }, + }, + }, msg) + + require.NoError(t, compileErr) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"tagGroups":[["a","b"],["c"]]}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "NestedListRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + tagGroupsField := desc.Fields().ByName("tag_groups") + lolosDesc := tagGroupsField.Message() // ListOfListOfString + + outerListField := lolosDesc.Fields().ByName("list") + outerListDesc := outerListField.Message() // ListOfListOfString.List + outerItemsField := outerListDesc.Fields().ByName("items") + losDesc := outerItemsField.Message() // ListOfString + + // Build ListOfString for ["a","b"] + buildListOfString := func(values ...string) *dynamicpb.Message { + innerListField := losDesc.Fields().ByName("list") + innerListDesc := innerListField.Message() + innerList := dynamicpb.NewMessage(innerListDesc) + items := innerList.Mutable(innerListDesc.Fields().ByName("items")).List() + for _, v := range values { + items.Append(protoref.ValueOfString(v)) + } + los := dynamicpb.NewMessage(losDesc) + los.Set(innerListField, protoref.ValueOfMessage(innerList)) + return los + } + + outerList := dynamicpb.NewMessage(outerListDesc) + outerItems := outerList.Mutable(outerItemsField).List() + outerItems.Append(protoref.ValueOfMessage(buildListOfString("a", "b"))) + outerItems.Append(protoref.ValueOfMessage(buildListOfString("c"))) + + lolos := dynamicpb.NewMessage(lolosDesc) + lolos.Set(outerListField, protoref.ValueOfMessage(outerList)) + + msg.Set(tagGroupsField, protoref.ValueOfMessage(lolos)) + }) + + assertProtoEqual(t, runtime, "NestedListRequest", expected, got) + }) + + t.Run("mixed request with multiple field types", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "MixedRequest", &RPCMessage{ + Name: "MixedRequest", + Fields: RPCFields{ + {Name: "id", ProtoTypeName: DataTypeString, JSONPath: "id"}, + {Name: "description", ProtoTypeName: DataTypeString, JSONPath: "description", Optional: true}, + {Name: "tags", ProtoTypeName: DataTypeString, JSONPath: "tags", Repeated: true}, + {Name: "price", ProtoTypeName: DataTypeDouble, JSONPath: "price"}, + {Name: "status", ProtoTypeName: DataTypeEnum, JSONPath: "status", EnumName: "Status"}, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"id":"p1","description":"a product","tags":["sale","new"],"price":29.99,"status":1}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "MixedRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + msg.Set(desc.Fields().ByName("id"), protoref.ValueOfString("p1")) + setWrapperValue(msg, "description", protoref.ValueOfString("a product")) + tagsList := msg.Mutable(desc.Fields().ByName("tags")).List() + tagsList.Append(protoref.ValueOfString("sale")) + tagsList.Append(protoref.ValueOfString("new")) + msg.Set(desc.Fields().ByName("price"), protoref.ValueOfFloat64(29.99)) + msg.Set(desc.Fields().ByName("status"), protoref.ValueOfEnum(1)) + }) + + assertProtoEqual(t, runtime, "MixedRequest", expected, got) + }) +} From f7aa79607ba4458f47a65702236fa61fff1dc460 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 22 Apr 2026 13:52:09 +0200 Subject: [PATCH 02/12] chore: add enum support to the wire message --- .../datasource/grpc_datasource/codec.go | 7 + .../datasource/grpc_datasource/compiler.go | 11 +- .../datasource/grpc_datasource/program.go | 2 +- .../grpc_datasource/program_test.go | 2 +- .../datasource/grpc_datasource/runtime.go | 44 ++++- .../grpc_datasource/runtime_test.go | 150 +++++++++++++++++- .../engine/datasource/grpc_datasource/wire.go | 67 ++++++-- .../datasource/grpc_datasource/wire_test.go | 126 +++++++++++++-- 8 files changed, 364 insertions(+), 45 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/codec.go b/v2/pkg/engine/datasource/grpc_datasource/codec.go index f2be1dc2c..2d2c91153 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/codec.go +++ b/v2/pkg/engine/datasource/grpc_datasource/codec.go @@ -1,6 +1,8 @@ package grpcdatasource import ( + "fmt" + "google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding/proto" _ "google.golang.org/grpc/encoding/proto" @@ -41,6 +43,11 @@ func (c *connectCodec) Marshal(v any) (out mem.BufferSlice, err error) { } } + // TODO: This should never happen + if defaultCodec == nil { + return nil, fmt.Errorf("default codec is nil") + } + return defaultCodec.Marshal(v) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/compiler.go b/v2/pkg/engine/datasource/grpc_datasource/compiler.go index e375c7271..08b974e07 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/compiler.go +++ b/v2/pkg/engine/datasource/grpc_datasource/compiler.go @@ -211,8 +211,8 @@ type EnumValue struct { // RPCCompiler compiles protobuf schema strings into a Document and can // build protobuf messages from JSON data based on the schema. type RPCCompiler struct { - doc *Document // The compiled Document - Ancestor []Message + doc *Document // The compiled Document + runtime *runtimeSchema // The compiled runtime schema } // ServiceByName returns a Service by its name. @@ -311,6 +311,13 @@ func NewProtoCompiler(schema string, mapping *GRPCMapping) (*RPCCompiler, error) // Process the schema file pc.processFile(schemaFile, mapping) + runtime, err := newSchemaRuntime(pc.doc) + if err != nil { + return nil, err + } + + pc.runtime = runtime + return pc, nil } diff --git a/v2/pkg/engine/datasource/grpc_datasource/program.go b/v2/pkg/engine/datasource/grpc_datasource/program.go index a0441a133..d8292f7ae 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/program.go +++ b/v2/pkg/engine/datasource/grpc_datasource/program.go @@ -118,7 +118,7 @@ func compileFetch(call *RPCCall, runtime *runtimeSchema, dependentCall *RPCCall) responseType: responseMessage, } - wireMessage, err := compileWireMessage(&f.request.rpcMessage, requestMessage) + wireMessage, err := compileWireMessage(runtime, &f.request.rpcMessage, requestMessage) if err != nil { return fetch{}, err } diff --git a/v2/pkg/engine/datasource/grpc_datasource/program_test.go b/v2/pkg/engine/datasource/grpc_datasource/program_test.go index cfbec33dc..9ddf0b022 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/program_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/program_test.go @@ -51,7 +51,7 @@ func TestCompileProgram(t *testing.T) { compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) require.NoError(t, err) - runtime, err := newSchemaRuntime(compiler) + runtime, err := newSchemaRuntime(compiler.doc) require.NoError(t, err) program, err := compileProgram(plan, runtime) diff --git a/v2/pkg/engine/datasource/grpc_datasource/runtime.go b/v2/pkg/engine/datasource/grpc_datasource/runtime.go index 013a3049f..d2d88e44d 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/runtime.go +++ b/v2/pkg/engine/datasource/grpc_datasource/runtime.go @@ -10,6 +10,7 @@ import ( type runtimeSchema struct { messageByName map[string]*runtimeMessage messageByFullname map[string]*runtimeMessage + enumByName map[string]*runtimeEnum serviceNamesByMethod map[string]string } @@ -20,6 +21,16 @@ type runtimeMessage struct { fieldsByName map[string]*runtimeField } +type runtimeEnum struct { + name string + valuesByName map[string]*runtimeEnumValue +} + +type runtimeEnumValue struct { + name string + value int32 +} + type runtimeField struct { name string owner *runtimeMessage @@ -31,15 +42,16 @@ type runtimeField struct { optional bool } -func newSchemaRuntime(compiler *RPCCompiler) (*runtimeSchema, error) { +func newSchemaRuntime(doc *Document) (*runtimeSchema, error) { runtime := &runtimeSchema{ - messageByName: make(map[string]*runtimeMessage, len(compiler.doc.Messages)), - messageByFullname: make(map[string]*runtimeMessage, len(compiler.doc.Messages)), - serviceNamesByMethod: make(map[string]string, len(compiler.doc.Methods)), + messageByName: make(map[string]*runtimeMessage, len(doc.Messages)), + messageByFullname: make(map[string]*runtimeMessage, len(doc.Messages)), + serviceNamesByMethod: make(map[string]string, len(doc.Methods)), + enumByName: make(map[string]*runtimeEnum, len(doc.Enums)), } - for i := range compiler.doc.Messages { - message := &compiler.doc.Messages[i] + for i := range doc.Messages { + message := &doc.Messages[i] rtMessage := &runtimeMessage{ name: message.Name, @@ -58,10 +70,26 @@ func newSchemaRuntime(compiler *RPCCompiler) (*runtimeSchema, error) { } } - for _, service := range compiler.doc.Services { + for _, service := range doc.Services { for i := range service.MethodsRefs { - runtime.serviceNamesByMethod[compiler.doc.Methods[i].Name] = service.FullName + runtime.serviceNamesByMethod[doc.Methods[i].Name] = service.FullName + } + } + + for _, enum := range doc.Enums { + rtEnum := &runtimeEnum{ + name: enum.Name, + valuesByName: make(map[string]*runtimeEnumValue, len(enum.Values)), + } + + for _, value := range enum.Values { + rtEnum.valuesByName[value.GraphqlValue] = &runtimeEnumValue{ + name: value.Name, + value: value.Number, + } } + + runtime.enumByName[enum.Name] = rtEnum } return runtime, nil diff --git a/v2/pkg/engine/datasource/grpc_datasource/runtime_test.go b/v2/pkg/engine/datasource/grpc_datasource/runtime_test.go index a4b224e47..7df3ed92f 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/runtime_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/runtime_test.go @@ -3,6 +3,7 @@ package grpcdatasource import ( "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -10,11 +11,139 @@ func TestNewSchemaRuntime(t *testing.T) { compiler, err := NewProtoCompiler(testSchemaWithLookup, testMapping()) require.NoError(t, err) - runtime, err := newSchemaRuntime(compiler) + runtime, err := newSchemaRuntime(compiler.doc) require.NoError(t, err) - require.Equal(t, 5, len(runtime.messageByName)) - require.Equal(t, 5, len(runtime.messageByFullname)) + require.Len(t, runtime.messageByName, 5) + require.Len(t, runtime.messageByFullname, 5) + require.Len(t, runtime.enumByName, 2) + require.Len(t, runtime.serviceNamesByMethod, 1) +} + +func TestSchemaRuntimeMessages(t *testing.T) { + compiler, err := NewProtoCompiler(testSchemaWithLookup, testMapping()) + require.NoError(t, err) + + runtime, err := newSchemaRuntime(compiler.doc) + require.NoError(t, err) + + t.Run("getMessageByName returns existing message", func(t *testing.T) { + msg := runtime.getMessageByName("Product") + require.NotNil(t, msg) + assert.Equal(t, "Product", msg.name) + }) + + t.Run("getMessageByName returns nil for unknown message", func(t *testing.T) { + msg := runtime.getMessageByName("NonExistent") + assert.Nil(t, msg) + }) + + t.Run("message has correct fields", func(t *testing.T) { + msg := runtime.getMessageByName("Product") + require.NotNil(t, msg) + + assert.Contains(t, msg.fieldsByName, "id") + assert.Contains(t, msg.fieldsByName, "name") + assert.Contains(t, msg.fieldsByName, "price") + assert.Contains(t, msg.fieldsByName, "status") + assert.Contains(t, msg.fieldsByName, "category") + }) + + t.Run("field data types are correct", func(t *testing.T) { + msg := runtime.getMessageByName("Product") + require.NotNil(t, msg) + + assert.Equal(t, DataTypeString, msg.fieldsByName["id"].dataType) + assert.Equal(t, DataTypeString, msg.fieldsByName["name"].dataType) + assert.Equal(t, DataTypeDouble, msg.fieldsByName["price"].dataType) + assert.Equal(t, DataTypeEnum, msg.fieldsByName["status"].dataType) + assert.Equal(t, DataTypeEnum, msg.fieldsByName["category"].dataType) + }) + + t.Run("repeated field is detected", func(t *testing.T) { + msg := runtime.getMessageByName("LookupProductByIdRequest") + require.NotNil(t, msg) + + field := msg.fieldsByName["inputs"] + require.NotNil(t, field) + assert.True(t, field.repeated) + }) + + t.Run("message field has child message reference", func(t *testing.T) { + msg := runtime.getMessageByName("LookupProductByIdResult") + require.NotNil(t, msg) + + field := msg.fieldsByName["product"] + require.NotNil(t, field) + assert.Equal(t, DataTypeMessage, field.dataType) + require.NotNil(t, field.message) + assert.Equal(t, "Product", field.message.name) + }) + + t.Run("newEmptyMessage creates a valid message", func(t *testing.T) { + msg := runtime.getMessageByName("Product") + require.NotNil(t, msg) + + empty := msg.newEmptyMessage() + require.NotNil(t, empty) + assert.Equal(t, msg.desc.FullName(), empty.Descriptor().FullName()) + }) +} + +func TestSchemaRuntimeEnums(t *testing.T) { + compiler, err := NewProtoCompiler(testSchemaWithLookup, testMapping()) + require.NoError(t, err) + + runtime, err := newSchemaRuntime(compiler.doc) + require.NoError(t, err) + + t.Run("enums are registered by name", func(t *testing.T) { + require.Contains(t, runtime.enumByName, "ProductStatus") + require.Contains(t, runtime.enumByName, "CategoryKind") + }) + + t.Run("enum has correct values", func(t *testing.T) { + productStatus := runtime.enumByName["ProductStatus"] + require.NotNil(t, productStatus) + assert.Equal(t, "ProductStatus", productStatus.name) + + // Values are keyed by GraphqlValue (mapped from the proto enum value name) + assert.NotEmpty(t, productStatus.valuesByName) + }) + + t.Run("enum values have correct numeric values", func(t *testing.T) { + categoryKind := runtime.enumByName["CategoryKind"] + require.NotNil(t, categoryKind) + + // The mapping transforms proto names to GraphQL names. + // Check that we have entries and they carry the right numeric values. + for _, v := range categoryKind.valuesByName { + assert.GreaterOrEqual(t, v.value, int32(0)) + assert.NotEmpty(t, v.name) + } + }) + + t.Run("unknown enum is not registered", func(t *testing.T) { + assert.NotContains(t, runtime.enumByName, "NonExistentEnum") + }) +} + +func TestSchemaRuntimeServices(t *testing.T) { + compiler, err := NewProtoCompiler(testSchemaWithLookup, testMapping()) + require.NoError(t, err) + + runtime, err := newSchemaRuntime(compiler.doc) + require.NoError(t, err) + + t.Run("service methods are registered", func(t *testing.T) { + assert.NotEmpty(t, runtime.serviceNamesByMethod) + }) + + t.Run("method maps to service name", func(t *testing.T) { + serviceName, ok := runtime.serviceNamesByMethod["LookupProductById"] + assert.True(t, ok) + assert.Equal(t, "product.v1.ProductService", serviceName) + }) } // =============== Test Schemas ================== // @@ -23,6 +152,19 @@ var testSchemaWithLookup = ` syntax = "proto3"; package product.v1; +enum ProductStatus { + PRODUCT_STATUS_UNSPECIFIED = 0; + PRODUCT_STATUS_ACTIVE = 1; + PRODUCT_STATUS_DISCONTINUED = 2; + PRODUCT_STATUS_OUT_OF_STOCK = 3; +} + +enum CategoryKind { + CATEGORY_KIND_UNSPECIFIED = 0; + CATEGORY_KIND_PHYSICAL = 1; + CATEGORY_KIND_DIGITAL = 2; +} + service ProductService { rpc LookupProductById(LookupProductByIdRequest) returns (LookupProductByIdResponse) {} } @@ -47,6 +189,8 @@ message Product { string id = 1; string name = 2; double price = 3; + ProductStatus status = 4; + CategoryKind category = 5; } ` diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire.go b/v2/pkg/engine/datasource/grpc_datasource/wire.go index 8e6208512..6a48e93a2 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/wire.go +++ b/v2/pkg/engine/datasource/grpc_datasource/wire.go @@ -41,6 +41,7 @@ type wireField struct { dataType DataType wireType protowire.Type runtimeMessage *runtimeMessage + runtimeEnum *runtimeEnum staticValue string jsonPath string optional bool @@ -50,7 +51,7 @@ type wireField struct { } const ( - minBufferSize = 1 << 10 // 1KiB + minBufferSize = 1 << 9 // 512 bytes maxBufferSize = 1 << 20 // 1MiB defaultBufferCount = 10 @@ -110,7 +111,7 @@ func (w *wireBuilder) Reset() { w.buffers = w.buffers[:0] } -func compileWireMessage(rpcMessage *RPCMessage, message *runtimeMessage) (*wireMessage, error) { +func compileWireMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, message *runtimeMessage) (*wireMessage, error) { if message == nil { return nil, fmt.Errorf("message not found for fetch request") } @@ -143,8 +144,27 @@ func compileWireMessage(rpcMessage *RPCMessage, message *runtimeMessage) (*wireM listMetadata: rpcField.ListMetadata, } + if rpcField.EnumName != "" { + rtEnum, ok := runtime.enumByName[rpcField.EnumName] + if !ok { + return nil, fmt.Errorf("enum not found for name %s", rpcField.EnumName) + } + + wf.runtimeEnum = rtEnum + } + if rpcField.Message != nil { - child, err := compileWireMessage(rpcField.Message, field.message) + fieldMessage := field.message + // we we are using wrapper messages, they are compiled from the protobuf schema but doesn't match with the RPC planner schema. + // We need to resolve the correct message from the runtime schema. + if rpcField.Message.Name != fieldMessage.name { + fieldMessage = runtime.getMessageByName(rpcField.Message.Name) + if fieldMessage == nil { + return nil, fmt.Errorf("message not found for name %s", rpcField.Message.Name) + } + } + + child, err := compileWireMessage(runtime, rpcField.Message, fieldMessage) if err != nil { return nil, err } @@ -202,7 +222,7 @@ func (f *wireField) appendFieldWire(builder *wireBuilder, buf *bytes.Buffer, dat return nil } - if f.isListWrapper() { + if f.listMetadata != nil { // TODO: build a wireMessage for the list wrapper and just create the proto wire for it //wm := &wireMessage{fields: make([]wireField, 0, 1)} return f.appendListFieldValue(builder, buf, fieldData, 0) @@ -235,7 +255,6 @@ func (f *wireField) appendListFieldValue(builder *wireBuilder, buf *bytes.Buffer return f.appendFieldWire(builder, buf, data) } - // TODO: Check for optional md := f.listMetadata.LevelInfo[level] level++ @@ -280,6 +299,7 @@ func (f *wireField) appendListFieldValue(builder *wireBuilder, buf *bytes.Buffer wireType: getWireType(itemsField.dataType), runtimeMessage: itemsField.message, listMetadata: f.listMetadata, + child: f.child, } iwf.tag = protowire.AppendTag(nil, iwf.number, iwf.wireType) @@ -328,10 +348,6 @@ func (f *wireField) appendOptionalScalarFieldValue(builder *wireBuilder, buf *by return nil } -func (f *wireField) isListWrapper() bool { - return f.listMetadata != nil -} - func (f *wireField) isOptionalScalar() bool { return f.optional && f.dataType != DataTypeMessage } @@ -341,7 +357,6 @@ func (f *wireField) appendFieldValue(builder *wireBuilder, buf *bytes.Buffer, da childWire, err := f.child.createProtoWire(builder, data) if err != nil { return err - } buf.Write(f.tag) @@ -351,13 +366,19 @@ func (f *wireField) appendFieldValue(builder *wireBuilder, buf *bytes.Buffer, da switch f.wireType { case protowire.BytesType: - value := data.GetStringBytes() buf.Write(f.tag) - buf.Write(protowire.AppendBytes(nil, value)) + buf.Write(protowire.AppendBytes(nil, data.GetStringBytes())) case protowire.VarintType: - uintValue := data.GetUint64() + value := data.GetUint64() + if f.runtimeEnum != nil { + var err error + if value, err = f.getEnumValue(data); err != nil { + return err + } + } + buf.Write(f.tag) - buf.Write(protowire.AppendVarint(nil, uintValue)) + buf.Write(protowire.AppendVarint(nil, value)) case protowire.Fixed64Type: buf.Write(f.tag) buf.Write(protowire.AppendFixed64(nil, math.Float64bits(data.GetFloat64()))) @@ -368,6 +389,24 @@ func (f *wireField) appendFieldValue(builder *wireBuilder, buf *bytes.Buffer, da return nil } +func (f *wireField) getEnumValue(data *astjson.Value) (uint64, error) { + enumValueName := data.GetStringBytes() + if len(enumValueName) == 0 { + return 0, fmt.Errorf("enum value name is required for enum field %s", f.jsonPath) + } + + ev, found := f.runtimeEnum.valuesByName[string(enumValueName)] + if !found { + return 0, fmt.Errorf("enum value not found for name %s", string(enumValueName)) + } + + if ev.value < 0 { + return 0, fmt.Errorf("enum value %s is negative for enum field %s", string(enumValueName), f.jsonPath) + } + + return uint64(ev.value), nil +} + func getWireType(dataType DataType) protowire.Type { switch dataType { case DataTypeString, DataTypeBytes: diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go index 972d476f4..814d66177 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go @@ -120,13 +120,24 @@ service TestService { } ` +func testWireMapping() *GRPCMapping { + return &GRPCMapping{ + Service: "TestService", + EnumValues: map[string][]EnumValueMapping{ + "Status": { + {Value: "UNSPECIFIED", TargetValue: "STATUS_UNSPECIFIED"}, + {Value: "ACTIVE", TargetValue: "STATUS_ACTIVE"}, + {Value: "INACTIVE", TargetValue: "STATUS_INACTIVE"}, + }, + }, + } +} + func newWireTestRuntime(t *testing.T) *runtimeSchema { t.Helper() - compiler, err := NewProtoCompiler(testWireSchema, &GRPCMapping{ - Service: "TestService", - }) + compiler, err := NewProtoCompiler(testWireSchema, testWireMapping()) require.NoError(t, err) - runtime, err := newSchemaRuntime(compiler) + runtime, err := newSchemaRuntime(compiler.doc) require.NoError(t, err) return runtime } @@ -135,7 +146,7 @@ func compileTestWireMessage(t *testing.T, runtime *runtimeSchema, messageName st t.Helper() msg := runtime.getMessageByName(messageName) require.NotNilf(t, msg, "message %q not found in runtime", messageName) - wm, err := compileWireMessage(rpcMessage, msg) + wm, err := compileWireMessage(runtime, rpcMessage, msg) require.NoError(t, err) return wm } @@ -155,20 +166,20 @@ func marshalDynamic(t *testing.T, runtime *runtimeSchema, messageName string, bu // assertProtoEqual unmarshals both byte slices into the same message type and compares the resulting messages. // This allows valid encoding differences (e.g. packed vs unpacked repeated scalars) to pass. -func assertProtoEqual(t *testing.T, runtime *runtimeSchema, messageName string, expected, got []byte) { +func assertProtoEqual(t *testing.T, runtime *runtimeSchema, messageName string, expectedMessageBytes, gotMessageBytes []byte) { t.Helper() rtMsg := runtime.getMessageByName(messageName) require.NotNilf(t, rtMsg, "message %q not found in runtime", messageName) expectedMsg := dynamicpb.NewMessage(rtMsg.desc) - require.NoError(t, proto.Unmarshal(expected, expectedMsg), "failed to unmarshal expected bytes") + require.NoError(t, proto.Unmarshal(expectedMessageBytes, expectedMsg), "failed to unmarshal expected bytes") gotMsg := dynamicpb.NewMessage(rtMsg.desc) - require.NoError(t, proto.Unmarshal(got, gotMsg), "failed to unmarshal got bytes") + require.NoError(t, proto.Unmarshal(gotMessageBytes, gotMsg), "failed to unmarshal got bytes %x", gotMessageBytes) assert.True(t, proto.Equal(expectedMsg, gotMsg), "messages not equal\nexpected: %v\ngot: %v\nexpected bytes: %x\ngot bytes: %x", - expectedMsg, gotMsg, expected, got) + expectedMsg, gotMsg, expectedMessageBytes, gotMessageBytes) } // setWrapperValue sets a google.protobuf wrapper field (e.g. StringValue, Int32Value) on a dynamic message. @@ -242,7 +253,7 @@ func TestCompileWireMessage(t *testing.T) { require.NotNil(t, msg) // Optional + IsListType: compileWireMessage must not treat this as a wrapper scalar. // Currently this errors because it tries to wrap in google.protobuf.*Value and looks for "value" in ListOfString. - wm, err := compileWireMessage(&RPCMessage{ + wm, err := compileWireMessage(runtime, &RPCMessage{ Name: "ListWrapperRequest", Fields: RPCFields{ { @@ -557,10 +568,11 @@ func TestCreateProtoWire(t *testing.T) { }, }) - // STATUS_ACTIVE = 1 + // GraphQL sends enum values as strings (e.g. "ACTIVE"), not proto-prefixed names or integers. + // The wire builder must resolve "ACTIVE" -> STATUS_ACTIVE = 1 via the runtime enum map. builder := newWireBuilder(minBufferSize) - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"status":1}`)) + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"status":"ACTIVE"}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "EnumRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -580,7 +592,7 @@ func TestCreateProtoWire(t *testing.T) { builder := newWireBuilder(minBufferSize) - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"statuses":[0,1,2]}`)) + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"statuses":["UNSPECIFIED","ACTIVE","INACTIVE"]}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "EnumRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -598,7 +610,7 @@ func TestCreateProtoWire(t *testing.T) { // createProtoWire must produce: ListWrapperRequest { optional_tags: ListOfString { list: List { items: [...] } } } msg := runtime.getMessageByName("ListWrapperRequest") require.NotNil(t, msg) - wm, compileErr := compileWireMessage(&RPCMessage{ + wm, compileErr := compileWireMessage(runtime, &RPCMessage{ Name: "ListWrapperRequest", Fields: RPCFields{ { @@ -649,7 +661,7 @@ func TestCreateProtoWire(t *testing.T) { // createProtoWire must produce: NestedListRequest { tag_groups: ListOfListOfString { list: { items: [ ListOfString{...}, ... ] } } } msg := runtime.getMessageByName("NestedListRequest") require.NotNil(t, msg) - wm, compileErr := compileWireMessage(&RPCMessage{ + wm, compileErr := compileWireMessage(runtime, &RPCMessage{ Name: "NestedListRequest", Fields: RPCFields{ { @@ -709,6 +721,88 @@ func TestCreateProtoWire(t *testing.T) { assertProtoEqual(t, runtime, "NestedListRequest", expected, got) }) + t.Run("nested list wrapper two levels with messages", func(t *testing.T) { + // NestedListRequest { item_groups: ListOfListOfNestedItem { list: { items: [ ListOfNestedItem{...}, ... ] } } } + // The inner ListOfNestedItem contains NestedItem messages with id + value fields. + msg := runtime.getMessageByName("NestedListRequest") + require.NotNil(t, msg) + wm, compileErr := compileWireMessage(runtime, &RPCMessage{ + Name: "NestedListRequest", + Fields: RPCFields{ + { + Name: "item_groups", + ProtoTypeName: DataTypeMessage, + JSONPath: "itemGroups", + IsListType: true, + ListMetadata: &ListMetadata{ + NestingLevel: 2, + LevelInfo: []LevelInfo{{Optional: false}, {Optional: false}}, + }, + Message: &RPCMessage{ + Name: "NestedItem", + Fields: RPCFields{ + {Name: "id", ProtoTypeName: DataTypeString, JSONPath: "id"}, + {Name: "value", ProtoTypeName: DataTypeString, JSONPath: "value"}, + }, + }, + }, + }, + }, msg) + + require.NoError(t, compileErr) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"itemGroups":[[{"id":"1","value":"a"},{"id":"2","value":"b"}],[{"id":"3","value":"c"}]]}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "NestedListRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + itemGroupsField := desc.Fields().ByName("item_groups") + loloniDesc := itemGroupsField.Message() // ListOfListOfNestedItem + + outerListField := loloniDesc.Fields().ByName("list") + outerListDesc := outerListField.Message() // ListOfListOfNestedItem.List + outerItemsField := outerListDesc.Fields().ByName("items") + loniDesc := outerItemsField.Message() // ListOfNestedItem + + // Helper: build a ListOfNestedItem from NestedItem values + buildListOfNestedItem := func(items ...struct{ id, value string }) *dynamicpb.Message { + innerListField := loniDesc.Fields().ByName("list") + innerListDesc := innerListField.Message() // ListOfNestedItem.List + nestedItemDesc := innerListDesc.Fields().ByName("items").Message() // NestedItem + + innerList := dynamicpb.NewMessage(innerListDesc) + itemsList := innerList.Mutable(innerListDesc.Fields().ByName("items")).List() + for _, item := range items { + ni := dynamicpb.NewMessage(nestedItemDesc) + ni.Set(nestedItemDesc.Fields().ByName("id"), protoref.ValueOfString(item.id)) + ni.Set(nestedItemDesc.Fields().ByName("value"), protoref.ValueOfString(item.value)) + itemsList.Append(protoref.ValueOfMessage(ni)) + } + loni := dynamicpb.NewMessage(loniDesc) + loni.Set(innerListField, protoref.ValueOfMessage(innerList)) + return loni + } + + outerList := dynamicpb.NewMessage(outerListDesc) + outerItems := outerList.Mutable(outerItemsField).List() + outerItems.Append(protoref.ValueOfMessage(buildListOfNestedItem( + struct{ id, value string }{"1", "a"}, + struct{ id, value string }{"2", "b"}, + ))) + outerItems.Append(protoref.ValueOfMessage(buildListOfNestedItem( + struct{ id, value string }{"3", "c"}, + ))) + + lolosni := dynamicpb.NewMessage(loloniDesc) + lolosni.Set(outerListField, protoref.ValueOfMessage(outerList)) + + msg.Set(itemGroupsField, protoref.ValueOfMessage(lolosni)) + }) + + assertProtoEqual(t, runtime, "NestedListRequest", expected, got) + }) + t.Run("mixed request with multiple field types", func(t *testing.T) { wm := compileTestWireMessage(t, runtime, "MixedRequest", &RPCMessage{ Name: "MixedRequest", @@ -723,7 +817,7 @@ func TestCreateProtoWire(t *testing.T) { builder := newWireBuilder(minBufferSize) - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"id":"p1","description":"a product","tags":["sale","new"],"price":29.99,"status":1}`)) + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"id":"p1","description":"a product","tags":["sale","new"],"price":29.99,"status":"ACTIVE"}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "MixedRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { From 7425342d3a4427a646416934d18c3d69ea59cf56 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 22 Apr 2026 14:33:40 +0200 Subject: [PATCH 03/12] chore: run first test in datasource --- .../grpc_datasource/grpc_datasource.go | 149 +++++++++++++++--- .../grpc_datasource/grpc_datasource_test.go | 24 +-- .../datasource/grpc_datasource/program.go | 16 ++ .../datasource/grpc_datasource/wire_test.go | 113 +++++++++++++ 4 files changed, 260 insertions(+), 42 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 8c15816d7..7de885282 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -18,6 +18,7 @@ import ( "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/types/dynamicpb" "github.com/wundergraph/astjson" "github.com/wundergraph/go-arena" @@ -100,6 +101,132 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D // The input is expected to contain the necessary information to make // a gRPC call, including service name, method name, and request data. func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { + // get variables from input + value, err := astjson.ParseBytes(input) + if err != nil { + return nil, err + } + + if value.Exists("body") { + value = value.Get("body") + } + + astJsonVariables := value.Get("variables") + if !value.Exists() { + return nil, fmt.Errorf("variables are required") + } + + variables := gjson.ParseBytes(input).Get("body.variables") + + _ = astJsonVariables + + var poolItems []*arena.PoolItem + defer func() { + d.pool.ReleaseMany(poolItems) + }() + + item := d.acquirePoolItem(input, 0) + poolItems = append(poolItems, item) + builder := newJSONBuilder(item.Arena, d.mapping, variables) + + if d.disabled { + return builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")), nil + } + + // convert headers to grpc metadata and attach to ctx + if len(headers) > 0 { + // assume that each header has exactly one value for default pairs size + pairs := make([]string, 0, len(headers)*2) + for headerName, headerValues := range headers { + headerName = strings.ToLower(headerName) + for _, v := range headerValues { + pairs = append(pairs, headerName, v) + } + } + ctx = metadata.AppendToOutgoingContext(ctx, pairs...) + } + + program, err := compileProgram(d.plan, d.rc.runtime) + if err != nil { + return nil, err + } + + root := astjson.ObjectValue(nil) + + for _, stage := range program.stages { + results := make([]resultData, len(stage.fetches)) + + for index, fetch := range stage.fetches { + responseMessage := dynamicpb.NewMessage(fetch.response.responseType.desc) + + wireMessage, err := compileWireMessage(d.rc.runtime, &fetch.request.rpcMessage, fetch.request.message) + if err != nil { + return nil, err + } + + wm, err := wireMessage.createProtoWire(newWireBuilder(minBufferSize), astJsonVariables) + if err != nil { + return nil, err + } + + pm := NewPreWiredInputMessage(wm) + err = d.cc.Invoke(ctx, fetch.methodFullName(), pm, responseMessage, grpc.ForceCodecV2(&connectCodec{})) + if err != nil { + return nil, err + } + + responseJson, err := builder.marshalResponseJSON(&fetch.response.rpcMessage, responseMessage) + if err != nil { + return nil, err + } + + results[index] = resultData{ + kind: fetch.kind, + response: responseJson, + responsePath: fetch.responsePath, + } + } + + for _, result := range results { + switch result.kind { + case CallKindResolve, CallKindRequired: + err = builder.mergeWithPath(root, result.response, result.responsePath) + default: + root, err = builder.mergeValues(root, result.response) + } + if err != nil { + return builder.writeErrorBytes(err), nil + } + } + } + + resultValue := builder.toDataObject(root) + return resultValue.MarshalTo(nil), err +} + +func (d *DataSource) acquirePoolItem(input []byte, index int) *arena.PoolItem { + keyGen := xxhash.New() + _, _ = keyGen.Write(input) + var b [8]byte + binary.LittleEndian.PutUint64(b[:], uint64(index)) + _, _ = keyGen.Write(b[:]) + key := keyGen.Sum64() + item := d.pool.Acquire(key) + return item +} + +// LoadWithFiles implements resolve.DataSource interface. +// Similar to Load, but handles file uploads if needed. +// +// Note: File uploads are typically not part of gRPC, so this method +// might not be applicable for most gRPC use cases. +// +// Currently unimplemented. +func (d *DataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { + panic("unimplemented") +} + +func (d *DataSource) LoadOld(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { // get variables from input variables := gjson.Parse(unsafebytes.BytesToString(input)).Get("body.variables") @@ -211,25 +338,3 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte value := builder.toDataObject(root) return value.MarshalTo(nil), err } - -func (d *DataSource) acquirePoolItem(input []byte, index int) *arena.PoolItem { - keyGen := xxhash.New() - _, _ = keyGen.Write(input) - var b [8]byte - binary.LittleEndian.PutUint64(b[:], uint64(index)) - _, _ = keyGen.Write(b[:]) - key := keyGen.Sum64() - item := d.pool.Acquire(key) - return item -} - -// LoadWithFiles implements resolve.DataSource interface. -// Similar to Load, but handles file uploads if needed. -// -// Note: File uploads are typically not part of gRPC, so this method -// might not be applicable for most gRPC use cases. -// -// Currently unimplemented. -func (d *DataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) (data []byte, err error) { - panic("unimplemented") -} diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 112ce26d6..9f70fbf4e 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -33,7 +33,7 @@ func Benchmark_DataSource_Load(b *testing.B) { schemaDoc := grpctest.MustGraphQLSchema(b) query := `query ComplexFilterTypeQuery($filter: ComplexFilterTypeInput!) { complexFilterType(filter: $filter) { id name } }` - variables := `{"variables":{"filter":{"name":"test","filterField1":"test","filterField2":"test"}}}` + variables := `{"variables":{"filter":{"filter":{"name":"test","filterField1":"test","filterField2":"test"}}}}` // Parse the GraphQL query queryDoc, report := astparser.ParseGraphqlDocumentString(query) @@ -175,7 +175,7 @@ func setupTestGRPCServer(t testing.TB) (conn *grpc.ClientConn, cleanup func()) { // Test_DataSource_Load tests the datasource.Load method with a mock gRPC interface func Test_DataSource_Load(t *testing.T) { query := `query ComplexFilterTypeQuery($filter: ComplexFilterTypeInput!) { complexFilterType(filter: $filter) { id name } }` - variables := `{"variables":{"filter":{"name":"test","filterField1":"test","filterField2":"test"}}}` + variables := `{"variables":{"filter":{"filter":{"name":"test","filterField1":"test","filterField2":"test"}}}}` // Parse the GraphQL schema schemaDoc := grpctest.MustGraphQLSchema(t) @@ -197,28 +197,12 @@ func Test_DataSource_Load(t *testing.T) { Definition: &schemaDoc, SubgraphName: "Products", Compiler: compiler, - Mapping: &GRPCMapping{ - Service: "Products", - QueryRPCs: RPCConfigMap[RPCConfig]{ - "complexFilterType": { - RPC: "QueryComplexFilterType", - Request: "QueryComplexFilterTypeRequest", - Response: "QueryComplexFilterTypeResponse", - }, - }, - Fields: map[string]FieldMap{ - "Query": { - "complexFilterType": { - TargetName: "complex_filter_type", - }, - }, - }, - }, + Mapping: testMapping(), }) require.NoError(t, err) - _, err = ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","variables":`+variables+`}`)) + _, err = ds.Load(context.Background(), nil, []byte(`{"query":"`+query+`","body":`+variables+`}`)) require.NoError(t, err) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/program.go b/v2/pkg/engine/datasource/grpc_datasource/program.go index d8292f7ae..041d4485c 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/program.go +++ b/v2/pkg/engine/datasource/grpc_datasource/program.go @@ -2,6 +2,7 @@ package grpcdatasource import ( "fmt" + "strings" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) @@ -36,6 +37,20 @@ type fetchRequest struct { type fetchResponse struct { // reponse type is the type of the response message. responseType *runtimeMessage + rpcMessage RPCMessage +} + +func (f *fetch) methodFullName() string { + var builder strings.Builder + + builder.Grow(len(f.serviceName) + len(f.methodName) + 2) + builder.WriteRune('/') + builder.WriteString(f.serviceName) + builder.WriteRune('/') + builder.WriteString(f.methodName) + + return builder.String() + } func compileProgram(plan *RPCExecutionPlan, runtime *runtimeSchema) (*program, error) { @@ -116,6 +131,7 @@ func compileFetch(call *RPCCall, runtime *runtimeSchema, dependentCall *RPCCall) f.response = &fetchResponse{ responseType: responseMessage, + rpcMessage: call.Response, } wireMessage, err := compileWireMessage(runtime, &f.request.rpcMessage, requestMessage) diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go index 814d66177..76f113edc 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go @@ -107,6 +107,14 @@ message MixedRequest { Status status = 7; } +message LookupProductByIdRequestKey { + string id = 1; +} + +message LookupProductByIdRequest { + repeated LookupProductByIdRequestKey keys = 1; +} + service TestService { rpc Empty(EmptyRequest) returns (EmptyRequest) {} rpc Scalar(ScalarRequest) returns (ScalarRequest) {} @@ -117,6 +125,7 @@ service TestService { rpc NestedList(NestedListRequest) returns (NestedListRequest) {} rpc Enum(EnumRequest) returns (EnumRequest) {} rpc Mixed(MixedRequest) returns (MixedRequest) {} + rpc LookupProductById(LookupProductByIdRequest) returns (LookupProductByIdRequest) {} } ` @@ -311,6 +320,32 @@ func TestCompileWireMessage(t *testing.T) { assert.Equal(t, DataTypeEnum, wm.fields[1].dataType) assert.True(t, wm.fields[1].repeated) }) + + t.Run("entity lookup request", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "LookupProductByIdRequest", &RPCMessage{ + Name: "LookupProductByIdRequest", + Fields: RPCFields{ + { + Name: "keys", + ProtoTypeName: DataTypeMessage, + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupProductByIdRequestKey", + MemberTypes: []string{"Product"}, + Fields: RPCFields{ + {Name: "id", ProtoTypeName: DataTypeString, JSONPath: "id"}, + }, + }, + }, + }, + }) + assert.Len(t, wm.fields, 1) + assert.True(t, wm.fields[0].repeated) + assert.Equal(t, DataTypeMessage, wm.fields[0].dataType) + assert.NotNil(t, wm.fields[0].child) + assert.Len(t, wm.fields[0].child.fields, 1) + }) } func TestCreateProtoWire(t *testing.T) { @@ -832,4 +867,82 @@ func TestCreateProtoWire(t *testing.T) { assertProtoEqual(t, runtime, "MixedRequest", expected, got) }) + + t.Run("entity lookup single key", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "LookupProductByIdRequest", &RPCMessage{ + Name: "LookupProductByIdRequest", + Fields: RPCFields{ + { + Name: "keys", + ProtoTypeName: DataTypeMessage, + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupProductByIdRequestKey", + MemberTypes: []string{"Product"}, + Fields: RPCFields{ + {Name: "id", ProtoTypeName: DataTypeString, JSONPath: "id"}, + }, + }, + }, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"representations":[{"__typename":"Product","id":"1"}]}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "LookupProductByIdRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + keysField := desc.Fields().ByName("keys") + keyDesc := keysField.Message() + list := msg.Mutable(keysField).List() + + key1 := dynamicpb.NewMessage(keyDesc) + key1.Set(keyDesc.Fields().ByName("id"), protoref.ValueOfString("1")) + list.Append(protoref.ValueOfMessage(key1)) + }) + + assertProtoEqual(t, runtime, "LookupProductByIdRequest", expected, got) + }) + + t.Run("entity lookup multiple keys", func(t *testing.T) { + wm := compileTestWireMessage(t, runtime, "LookupProductByIdRequest", &RPCMessage{ + Name: "LookupProductByIdRequest", + Fields: RPCFields{ + { + Name: "keys", + ProtoTypeName: DataTypeMessage, + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupProductByIdRequestKey", + MemberTypes: []string{"Product"}, + Fields: RPCFields{ + {Name: "id", ProtoTypeName: DataTypeString, JSONPath: "id"}, + }, + }, + }, + }, + }) + + builder := newWireBuilder(minBufferSize) + + got, err := wm.createProtoWire(builder, astjson.MustParse(`{"representations":[{"__typename":"Product","id":"1"},{"__typename":"Product","id":"2"},{"__typename":"Product","id":"3"}]}`)) + require.NoError(t, err) + + expected := marshalDynamic(t, runtime, "LookupProductByIdRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { + keysField := desc.Fields().ByName("keys") + keyDesc := keysField.Message() + list := msg.Mutable(keysField).List() + + for _, id := range []string{"1", "2", "3"} { + key := dynamicpb.NewMessage(keyDesc) + key.Set(keyDesc.Fields().ByName("id"), protoref.ValueOfString(id)) + list.Append(protoref.ValueOfMessage(key)) + } + }) + + assertProtoEqual(t, runtime, "LookupProductByIdRequest", expected, got) + }) } From 00c64f0fa792cc6864a2d4fe6b3f20dddb1e269a Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 22 Apr 2026 15:10:44 +0200 Subject: [PATCH 04/12] chore: reduce number of allocations --- .../grpc_datasource/grpc_datasource.go | 62 ++++---- .../datasource/grpc_datasource/program.go | 46 ++---- .../engine/datasource/grpc_datasource/wire.go | 150 +++++++----------- .../datasource/grpc_datasource/wire_test.go | 72 +++------ 4 files changed, 128 insertions(+), 202 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 7de885282..88e1a1f5d 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -7,6 +7,7 @@ package grpcdatasource import ( + "bytes" "context" "encoding/binary" "fmt" @@ -52,7 +53,11 @@ type DataSource struct { definition *ast.Document disabled bool - pool *arena.Pool + pool *arena.Pool + program *program + codecOpt grpc.CallOption + wireBuf bytes.Buffer + resultsBuf []resultData } type ProtoConfig struct { @@ -79,6 +84,10 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D if err != nil { return nil, err } + program, err := compileProgram(plan, config.Compiler.runtime) + if err != nil { + return nil, err + } return &DataSource{ plan: plan, @@ -89,6 +98,8 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D federationConfigs: config.FederationConfigs, disabled: config.Disabled, pool: arena.NewArenaPool(), + program: program, + codecOpt: grpc.ForceCodecV2(&connectCodec{}), }, nil } @@ -101,8 +112,15 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D // The input is expected to contain the necessary information to make // a gRPC call, including service name, method name, and request data. func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { + var poolItems []*arena.PoolItem + defer func() { + d.pool.ReleaseMany(poolItems) + }() + + item := d.acquirePoolItem(input, 0) + poolItems = append(poolItems, item) // get variables from input - value, err := astjson.ParseBytes(input) + value, err := astjson.ParseBytesWithArena(item.Arena, input) if err != nil { return nil, err } @@ -120,13 +138,6 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte _ = astJsonVariables - var poolItems []*arena.PoolItem - defer func() { - d.pool.ReleaseMany(poolItems) - }() - - item := d.acquirePoolItem(input, 0) - poolItems = append(poolItems, item) builder := newJSONBuilder(item.Arena, d.mapping, variables) if d.disabled { @@ -146,31 +157,24 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte ctx = metadata.AppendToOutgoingContext(ctx, pairs...) } - program, err := compileProgram(d.plan, d.rc.runtime) - if err != nil { - return nil, err - } - root := astjson.ObjectValue(nil) - for _, stage := range program.stages { - results := make([]resultData, len(stage.fetches)) + for _, stage := range d.program.stages { + results := d.resultsBuf[:0] + if cap(results) < len(stage.fetches) { + results = make([]resultData, 0, len(stage.fetches)) + } - for index, fetch := range stage.fetches { + for _, fetch := range stage.fetches { responseMessage := dynamicpb.NewMessage(fetch.response.responseType.desc) - wireMessage, err := compileWireMessage(d.rc.runtime, &fetch.request.rpcMessage, fetch.request.message) - if err != nil { + d.wireBuf.Reset() + if err = fetch.request.wire.appendProtoWire(&d.wireBuf, astJsonVariables); err != nil { return nil, err } - wm, err := wireMessage.createProtoWire(newWireBuilder(minBufferSize), astJsonVariables) - if err != nil { - return nil, err - } - - pm := NewPreWiredInputMessage(wm) - err = d.cc.Invoke(ctx, fetch.methodFullName(), pm, responseMessage, grpc.ForceCodecV2(&connectCodec{})) + pm := NewPreWiredInputMessage(d.wireBuf.Bytes()) + err = d.cc.Invoke(ctx, fetch.methodFullName, pm, responseMessage, d.codecOpt) if err != nil { return nil, err } @@ -180,13 +184,15 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte return nil, err } - results[index] = resultData{ + results = append(results, resultData{ kind: fetch.kind, response: responseJson, responsePath: fetch.responsePath, - } + }) } + d.resultsBuf = results + for _, result := range results { switch result.kind { case CallKindResolve, CallKindRequired: diff --git a/v2/pkg/engine/datasource/grpc_datasource/program.go b/v2/pkg/engine/datasource/grpc_datasource/program.go index 041d4485c..eac0063e9 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/program.go +++ b/v2/pkg/engine/datasource/grpc_datasource/program.go @@ -2,7 +2,6 @@ package grpcdatasource import ( "fmt" - "strings" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) @@ -16,14 +15,15 @@ type stage struct { } type fetch struct { - id int - kind CallKind - dependentCall *RPCCall - serviceName string - methodName string - responsePath ast.Path - request *fetchRequest - response *fetchResponse + id int + kind CallKind + dependentCall *RPCCall + serviceName string + methodName string + methodFullName string + responsePath ast.Path + request *fetchRequest + response *fetchResponse } type fetchRequest struct { @@ -35,24 +35,11 @@ type fetchRequest struct { } type fetchResponse struct { - // reponse type is the type of the response message. + // response type is the type of the response message. responseType *runtimeMessage rpcMessage RPCMessage } -func (f *fetch) methodFullName() string { - var builder strings.Builder - - builder.Grow(len(f.serviceName) + len(f.methodName) + 2) - builder.WriteRune('/') - builder.WriteString(f.serviceName) - builder.WriteRune('/') - builder.WriteString(f.methodName) - - return builder.String() - -} - func compileProgram(plan *RPCExecutionPlan, runtime *runtimeSchema) (*program, error) { stageIndexes, err := compileStageIndexes(plan) if err != nil { @@ -106,12 +93,13 @@ func compileFetch(call *RPCCall, runtime *runtimeSchema, dependentCall *RPCCall) } f := fetch{ - id: call.ID, - kind: call.Kind, - dependentCall: dependentCall, - serviceName: serviceName, - methodName: call.MethodName, - responsePath: call.ResponsePath, + id: call.ID, + kind: call.Kind, + dependentCall: dependentCall, + serviceName: serviceName, + methodName: call.MethodName, + methodFullName: "/" + serviceName + "/" + call.MethodName, + responsePath: call.ResponsePath, } requestMessage := runtime.getMessageByName(call.Request.Name) diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire.go b/v2/pkg/engine/datasource/grpc_datasource/wire.go index 6a48e93a2..7f54d25fe 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/wire.go +++ b/v2/pkg/engine/datasource/grpc_datasource/wire.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "math" - "sync" "github.com/wundergraph/astjson" "google.golang.org/protobuf/encoding/protowire" @@ -51,65 +50,35 @@ type wireField struct { } const ( - minBufferSize = 1 << 9 // 512 bytes - maxBufferSize = 1 << 20 // 1MiB - - defaultBufferCount = 10 + minBufferSize = 1 << 8 // 256 bytes ) -// TODO: This should use an area instead -type bufferPool struct { - pool sync.Pool - defaultSize int +// writeVarint writes a varint to buf using a stack-allocated scratch buffer to avoid heap allocation. +func writeVarint(buf *bytes.Buffer, v uint64) { + var scratch [10]byte + buf.Write(protowire.AppendVarint(scratch[:0], v)) } -func newMempool(size int) *bufferPool { - if size <= 0 || size < minBufferSize || size > maxBufferSize { - size = minBufferSize - } - - return &bufferPool{ - pool: sync.Pool{ - New: func() any { - return bytes.NewBuffer(make([]byte, 0, size)) - }, - }, - defaultSize: size, - } +// writeFixed64 writes a fixed64 to buf using a stack-allocated scratch buffer to avoid heap allocation. +func writeFixed64(buf *bytes.Buffer, v uint64) { + var scratch [8]byte + buf.Write(protowire.AppendFixed64(scratch[:0], v)) } -func (b *bufferPool) Get() *bytes.Buffer { - buf := b.pool.Get().(*bytes.Buffer) - buf.Reset() - if buf.Cap() < b.defaultSize { - buf.Grow(b.defaultSize - buf.Cap()) - } - - return buf +// writeLengthPrefixed writes a length-delimited field value (length varint + raw bytes) to buf +// without allocating, unlike protowire.AppendBytes(nil, data). +func writeLengthPrefixed(buf *bytes.Buffer, data []byte) { + var scratch [10]byte + buf.Write(protowire.AppendVarint(scratch[:0], uint64(len(data)))) + buf.Write(data) } -func (b *bufferPool) Put(buf *bytes.Buffer) { - b.pool.Put(buf) +// writeTag writes a protobuf tag to buf using a stack-allocated scratch buffer. +func writeTag(buf *bytes.Buffer, num protowire.Number, typ protowire.Type) { + var scratch [10]byte + buf.Write(protowire.AppendTag(scratch[:0], num, typ)) } -type wireBuilder struct { - pool *bufferPool - buffers []*bytes.Buffer -} - -func newWireBuilder(size int) *wireBuilder { - return &wireBuilder{ - pool: newMempool(size), - buffers: make([]*bytes.Buffer, 0, defaultBufferCount), - } -} - -func (w *wireBuilder) Reset() { - for _, buf := range w.buffers { - w.pool.Put(buf) - } - w.buffers = w.buffers[:0] -} func compileWireMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, message *runtimeMessage) (*wireMessage, error) { if message == nil { @@ -180,21 +149,25 @@ func compileWireMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, message } // createProtoWire creates a proto wire from the wire plan. -func (w *wireMessage) createProtoWire(builder *wireBuilder, data *astjson.Value) ([]byte, error) { - // TODO: Use arena or a global buffer pool - buf := builder.pool.Get() +func (w *wireMessage) createProtoWire(data *astjson.Value) ([]byte, error) { + buf := bytes.NewBuffer(make([]byte, 0, minBufferSize)) + if err := w.appendProtoWire(buf, data); err != nil { + return nil, err + } + return buf.Bytes(), nil +} +// appendProtoWire encodes the message fields into the given buffer. +func (w *wireMessage) appendProtoWire(buf *bytes.Buffer, data *astjson.Value) error { for _, field := range w.fields { - err := field.appendFieldWire(builder, buf, data) - if err != nil { - return nil, err + if err := field.appendFieldWire(buf, data); err != nil { + return err } } - - return buf.Bytes(), nil + return nil } -func (f *wireField) appendFieldWire(builder *wireBuilder, buf *bytes.Buffer, data *astjson.Value) error { +func (f *wireField) appendFieldWire(buf *bytes.Buffer, data *astjson.Value) error { var fieldData *astjson.Value if f.jsonPath == "" { @@ -213,7 +186,7 @@ func (f *wireField) appendFieldWire(builder *wireBuilder, buf *bytes.Buffer, dat if f.repeated { for _, element := range fieldData.GetArray() { - err := f.appendFieldValue(builder, buf, element) + err := f.appendFieldValue(buf, element) if err != nil { return err } @@ -225,14 +198,14 @@ func (f *wireField) appendFieldWire(builder *wireBuilder, buf *bytes.Buffer, dat if f.listMetadata != nil { // TODO: build a wireMessage for the list wrapper and just create the proto wire for it //wm := &wireMessage{fields: make([]wireField, 0, 1)} - return f.appendListFieldValue(builder, buf, fieldData, 0) + return f.appendListFieldValue(buf, fieldData, 0) } if f.isOptionalScalar() { - return f.appendOptionalScalarFieldValue(builder, buf, fieldData) + return f.appendOptionalScalarFieldValue(buf, fieldData) } - return f.appendFieldValue(builder, buf, fieldData) + return f.appendFieldValue(buf, fieldData) } // appendListFieldValue appends the list value to the buffer. @@ -249,10 +222,10 @@ func (f *wireField) appendFieldWire(builder *wireBuilder, buf *bytes.Buffer, dat // } // // ``` -func (f *wireField) appendListFieldValue(builder *wireBuilder, buf *bytes.Buffer, data *astjson.Value, level int) error { +func (f *wireField) appendListFieldValue(buf *bytes.Buffer, data *astjson.Value, level int) error { if level >= f.listMetadata.NestingLevel { f.listMetadata = nil // reset the list metadata to avoid infinite recursion - return f.appendFieldWire(builder, buf, data) + return f.appendFieldWire(buf, data) } md := f.listMetadata.LevelInfo[level] @@ -263,17 +236,11 @@ func (f *wireField) appendListFieldValue(builder *wireBuilder, buf *bytes.Buffer return fmt.Errorf("runtime message not found for field %s", f.jsonPath) } - listBuffer := builder.pool.Get() - defer builder.pool.Put(listBuffer) - field, ok := runtimeMsg.fieldsByName["list"] if !ok { return fmt.Errorf("list field not found for message %s but was expected", runtimeMsg.name) } - // We will always have a message type here, therefore we must use the bytes type. - listBuffer.Write(protowire.AppendTag(nil, field.desc.Number(), protowire.BytesType)) - listMessage := field.message if listMessage == nil { return fmt.Errorf("expected nested message type for list wrapper field but the field %s doesn't have a message", f.jsonPath) @@ -289,8 +256,12 @@ func (f *wireField) appendListFieldValue(builder *wireBuilder, buf *bytes.Buffer return fmt.Errorf("list is required but has no elements") } - itemsBuffer := builder.pool.Get() - defer builder.pool.Put(itemsBuffer) + listBuffer := bytes.NewBuffer(make([]byte, 0, minBufferSize)) + + // We will always have a message type here, therefore we must use the bytes type. + writeTag(listBuffer, field.desc.Number(), protowire.BytesType) + + itemsBuffer := bytes.NewBuffer(make([]byte, 0, minBufferSize)) for i := range elements { iwf := wireField{ @@ -303,20 +274,20 @@ func (f *wireField) appendListFieldValue(builder *wireBuilder, buf *bytes.Buffer } iwf.tag = protowire.AppendTag(nil, iwf.number, iwf.wireType) - if err := iwf.appendListFieldValue(builder, itemsBuffer, elements[i], level); err != nil { + if err := iwf.appendListFieldValue(itemsBuffer, elements[i], level); err != nil { return err } } - listBuffer.Write(protowire.AppendBytes(nil, itemsBuffer.Bytes())) + writeLengthPrefixed(listBuffer, itemsBuffer.Bytes()) buf.Write(f.tag) - buf.Write(protowire.AppendBytes(nil, listBuffer.Bytes())) + writeLengthPrefixed(buf, listBuffer.Bytes()) return nil } -func (f *wireField) appendOptionalScalarFieldValue(builder *wireBuilder, buf *bytes.Buffer, data *astjson.Value) error { +func (f *wireField) appendOptionalScalarFieldValue(buf *bytes.Buffer, data *astjson.Value) error { if f.runtimeMessage == nil { return fmt.Errorf("runtime message not found for optional scalar field %s but was expected", f.jsonPath) } @@ -326,9 +297,6 @@ func (f *wireField) appendOptionalScalarFieldValue(builder *wireBuilder, buf *by return fmt.Errorf("wrapper field not found for message %s but was expected", f.runtimeMessage.name) } - fieldBuf := builder.pool.Get() - defer builder.pool.Put(fieldBuf) - wf := wireField{ number: wrapperField.desc.Number(), dataType: wrapperField.dataType, @@ -338,13 +306,14 @@ func (f *wireField) appendOptionalScalarFieldValue(builder *wireBuilder, buf *by } wf.tag = protowire.AppendTag(nil, wf.number, wf.wireType) - err := wf.appendFieldValue(builder, fieldBuf, data) - if err != nil { + + fieldBuf := bytes.NewBuffer(make([]byte, 0, minBufferSize)) + if err := wf.appendFieldValue(fieldBuf, data); err != nil { return err } buf.Write(f.tag) - buf.Write(protowire.AppendBytes(nil, fieldBuf.Bytes())) + writeLengthPrefixed(buf, fieldBuf.Bytes()) return nil } @@ -352,22 +321,21 @@ func (f *wireField) isOptionalScalar() bool { return f.optional && f.dataType != DataTypeMessage } -func (f *wireField) appendFieldValue(builder *wireBuilder, buf *bytes.Buffer, data *astjson.Value) error { +func (f *wireField) appendFieldValue(buf *bytes.Buffer, data *astjson.Value) error { if f.child != nil { - childWire, err := f.child.createProtoWire(builder, data) - if err != nil { + childBuf := bytes.NewBuffer(make([]byte, 0, minBufferSize)) + if err := f.child.appendProtoWire(childBuf, data); err != nil { return err } - buf.Write(f.tag) - buf.Write(protowire.AppendBytes(nil, childWire)) + writeLengthPrefixed(buf, childBuf.Bytes()) return nil } switch f.wireType { case protowire.BytesType: buf.Write(f.tag) - buf.Write(protowire.AppendBytes(nil, data.GetStringBytes())) + writeLengthPrefixed(buf, data.GetStringBytes()) case protowire.VarintType: value := data.GetUint64() if f.runtimeEnum != nil { @@ -378,10 +346,10 @@ func (f *wireField) appendFieldValue(builder *wireBuilder, buf *bytes.Buffer, da } buf.Write(f.tag) - buf.Write(protowire.AppendVarint(nil, value)) + writeVarint(buf, value) case protowire.Fixed64Type: buf.Write(f.tag) - buf.Write(protowire.AppendFixed64(nil, math.Float64bits(data.GetFloat64()))) + writeFixed64(buf, math.Float64bits(data.GetFloat64())) default: return fmt.Errorf("unsupported wire type %d", f.wireType) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go index 76f113edc..b5a61f2ae 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go @@ -357,9 +357,7 @@ func TestCreateProtoWire(t *testing.T) { Fields: nil, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "EmptyRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) {}) @@ -375,9 +373,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"name":"hello"}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"name":"hello"}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "ScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -397,9 +393,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"name":"alice","age":30,"score":99.5}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"name":"alice","age":30,"score":99.5}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "ScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -421,9 +415,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"name":"hello"}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"name":"hello"}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "WrapperScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -441,9 +433,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "WrapperScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -462,9 +452,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"age":25,"score":3.14}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"age":25,"score":3.14}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "WrapperScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -483,9 +471,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"tags":["foo","bar","baz"]}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"tags":["foo","bar","baz"]}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "RepeatedScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -506,9 +492,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"scores":[1,2,3]}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"scores":[1,2,3]}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "RepeatedScalarRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -538,9 +522,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"item":{"id":"1","value":"a"}}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"item":{"id":"1","value":"a"}}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "NestedMessageRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -571,9 +553,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"items":[{"id":"1","value":"a"},{"id":"2","value":"b"}]}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"items":[{"id":"1","value":"a"},{"id":"2","value":"b"}]}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "NestedMessageRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -605,9 +585,7 @@ func TestCreateProtoWire(t *testing.T) { // GraphQL sends enum values as strings (e.g. "ACTIVE"), not proto-prefixed names or integers. // The wire builder must resolve "ACTIVE" -> STATUS_ACTIVE = 1 via the runtime enum map. - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"status":"ACTIVE"}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"status":"ACTIVE"}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "EnumRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -625,9 +603,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"statuses":["UNSPECIFIED","ACTIVE","INACTIVE"]}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"statuses":["UNSPECIFIED","ACTIVE","INACTIVE"]}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "EnumRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -664,9 +640,7 @@ func TestCreateProtoWire(t *testing.T) { require.NoError(t, compileErr) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"optionalTags":["a","b"]}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"optionalTags":["a","b"]}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "ListWrapperRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -714,9 +688,7 @@ func TestCreateProtoWire(t *testing.T) { require.NoError(t, compileErr) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"tagGroups":[["a","b"],["c"]]}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"tagGroups":[["a","b"],["c"]]}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "NestedListRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -786,9 +758,7 @@ func TestCreateProtoWire(t *testing.T) { require.NoError(t, compileErr) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"itemGroups":[[{"id":"1","value":"a"},{"id":"2","value":"b"}],[{"id":"3","value":"c"}]]}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"itemGroups":[[{"id":"1","value":"a"},{"id":"2","value":"b"}],[{"id":"3","value":"c"}]]}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "NestedListRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -850,9 +820,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"id":"p1","description":"a product","tags":["sale","new"],"price":29.99,"status":"ACTIVE"}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"id":"p1","description":"a product","tags":["sale","new"],"price":29.99,"status":"ACTIVE"}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "MixedRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -888,9 +856,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"representations":[{"__typename":"Product","id":"1"}]}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"representations":[{"__typename":"Product","id":"1"}]}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "LookupProductByIdRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { @@ -926,9 +892,7 @@ func TestCreateProtoWire(t *testing.T) { }, }) - builder := newWireBuilder(minBufferSize) - - got, err := wm.createProtoWire(builder, astjson.MustParse(`{"representations":[{"__typename":"Product","id":"1"},{"__typename":"Product","id":"2"},{"__typename":"Product","id":"3"}]}`)) + got, err := wm.createProtoWire(astjson.MustParse(`{"representations":[{"__typename":"Product","id":"1"},{"__typename":"Product","id":"2"},{"__typename":"Product","id":"3"}]}`)) require.NoError(t, err) expected := marshalDynamic(t, runtime, "LookupProductByIdRequest", func(msg *dynamicpb.Message, desc protoref.MessageDescriptor) { From a4e810db4c48aff81d8a79c768d0b2f307ce4b51 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 22 Apr 2026 16:12:43 +0200 Subject: [PATCH 05/12] chore: register codec --- v2/pkg/engine/datasource/grpc_datasource/codec.go | 4 ++++ .../engine/datasource/grpc_datasource/grpc_datasource.go | 7 +++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/codec.go b/v2/pkg/engine/datasource/grpc_datasource/codec.go index 2d2c91153..712a8f086 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/codec.go +++ b/v2/pkg/engine/datasource/grpc_datasource/codec.go @@ -13,6 +13,10 @@ var defaultCodec = encoding.GetCodecV2("proto") type connectCodec struct{} +func init() { + encoding.RegisterCodecV2(&connectCodec{}) +} + // Name implements [encoding.CodecV2]. func (c *connectCodec) Name() string { // we use the default proto codec to allow marshalling our own message but not diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 88e1a1f5d..9d0d8dadf 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -168,13 +168,12 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte for _, fetch := range stage.fetches { responseMessage := dynamicpb.NewMessage(fetch.response.responseType.desc) - d.wireBuf.Reset() - if err = fetch.request.wire.appendProtoWire(&d.wireBuf, astJsonVariables); err != nil { + buffer, err := fetch.request.wire.createProtoWire(astJsonVariables) + if err != nil { return nil, err } - pm := NewPreWiredInputMessage(d.wireBuf.Bytes()) - err = d.cc.Invoke(ctx, fetch.methodFullName, pm, responseMessage, d.codecOpt) + err = d.cc.Invoke(ctx, fetch.methodFullName, NewPreWiredInputMessage(buffer), responseMessage) if err != nil { return nil, err } From 9e2693ab9c1709738690e0632b4efb7e1fb598d0 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Thu, 23 Apr 2026 14:36:00 +0200 Subject: [PATCH 06/12] chore: improve program compilation design --- v2/pkg/ast/path.go | 8 + .../datasource/grpc_datasource/entity.go | 49 ++- .../datasource/grpc_datasource/entity_test.go | 29 +- .../grpc_datasource/execution_plan.go | 4 + .../execution_plan_field_resolvers_test.go | 5 +- .../grpc_datasource/grpc_datasource.go | 320 ++++++++++-------- .../grpc_datasource/grpc_datasource_test.go | 32 +- .../grpc_datasource/json_builder.go | 4 +- .../datasource/grpc_datasource/program.go | 246 ++++++++++++-- .../grpc_datasource/program_test.go | 8 +- .../grpc_datasource/runtime_test.go | 8 +- .../engine/datasource/grpc_datasource/wire.go | 207 +++++++---- .../datasource/grpc_datasource/wire_test.go | 162 +++++---- 13 files changed, 705 insertions(+), 377 deletions(-) diff --git a/v2/pkg/ast/path.go b/v2/pkg/ast/path.go index 19493ec60..392705fa2 100644 --- a/v2/pkg/ast/path.go +++ b/v2/pkg/ast/path.go @@ -128,6 +128,14 @@ func (p Path) String() string { return out } +func (p Path) ToPathItemStrings() []string { + out := make([]string, len(p)) + for i := range p { + out[i] = unsafebytes.BytesToString(p[i].FieldName) + } + return out +} + func (p Path) DotDelimitedString() string { builder := strings.Builder{} diff --git a/v2/pkg/engine/datasource/grpc_datasource/entity.go b/v2/pkg/engine/datasource/grpc_datasource/entity.go index 850bb0606..f2a2e79ca 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/entity.go +++ b/v2/pkg/engine/datasource/grpc_datasource/entity.go @@ -4,9 +4,8 @@ import ( "errors" "fmt" - "github.com/tidwall/gjson" - "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" ) // entityIndexMap maps positions in the typed gRPC response back to positions @@ -19,25 +18,55 @@ type entityIndexMap []int // newEntityIndexMap builds the index map for a single entity call by collecting // the positions of representations whose __typename matches the requested type. // A single pass over representations populates the slice. -func newEntityIndexMap(requestedEntityType string, representations []gjson.Result) entityIndexMap { +func newEntityIndexMap(requestedEntityType string, representations []*astjson.Value) entityIndexMap { indexMap := make(entityIndexMap, 0, len(representations)) for i, representation := range representations { - if representation.Get(typenameFieldName).String() == requestedEntityType { + if string(representation.Get(typenameFieldName).GetStringBytes()) == requestedEntityType { indexMap = append(indexMap, i) } } return indexMap } -// getRepresentations gets the representations from the variables. -// If no representations are found, it returns nil. -func getRepresentations(variables gjson.Result) []gjson.Result { +// getRepresentationsAST gets the representations from the variables. +// If no representations are found, it returns an empty slice. +func getRepresentations(variables *astjson.Value) []*astjson.Value { + r := variables.Get("representations") + if !r.Exists() { + return nil + } + + arr := r.GetArray() + if len(arr) == 0 { + return make([]*astjson.Value, 0) + } + + return arr +} + +// filterRepresentations filters the representations to only include the ones of the requested entity type. +func filterRepresentations(arena arena.Arena, variables *astjson.Value, requestedEntityType string) *astjson.Value { r := variables.Get("representations") if !r.Exists() { return nil } - return r.Array() + representations := r.GetArray() + if len(representations) == 0 { + return nil + } + + ov := astjson.ObjectValue(arena) + representationsArr := astjson.ArrayValue(arena) + + for _, representation := range representations { + if string(representation.Get(typenameFieldName).GetStringBytes()) == requestedEntityType { + representationsArr.SetArrayItem(arena, len(representationsArr.GetArray()), representation) + } + } + + ov.Set(arena, "representations", representationsArr) + return ov } // validateEntityResponse verifies that the number of entities returned by the @@ -45,7 +74,7 @@ func getRepresentations(variables gjson.Result) []gjson.Result { // Callers should subsequently build an entityIndexMap via newEntityIndexMap to // merge the response — mergeEntities relies on the invariant that // len(response entities) == len(indexMap), which this function establishes. -func validateEntityResponse(data *astjson.Value, requestedEntityType string, representations []gjson.Result) error { +func validateEntityResponse(data *astjson.Value, requestedEntityType string, representations []*astjson.Value) error { if data == nil { return errors.New("validateEntityResponse: subgraph response data is nil") } @@ -60,7 +89,7 @@ func validateEntityResponse(data *astjson.Value, requestedEntityType string, rep expected := 0 for _, representation := range representations { - if representation.Get(typenameFieldName).String() == requestedEntityType { + if string(representation.Get(typenameFieldName).GetStringBytes()) == requestedEntityType { expected++ } } diff --git a/v2/pkg/engine/datasource/grpc_datasource/entity_test.go b/v2/pkg/engine/datasource/grpc_datasource/entity_test.go index d9d7e10e3..3e1c3d607 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/entity_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/entity_test.go @@ -4,14 +4,13 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/tidwall/gjson" "github.com/wundergraph/astjson" ) func TestNewEntityIndexMap(t *testing.T) { t.Run("returns empty map when no representations match", func(t *testing.T) { - reps := getRepresentations(gjson.Parse(`{"representations":[ + reps := getRepresentations(astjson.MustParse(`{"representations":[ {"__typename":"Storage","id":"1"} ]}`)) idx := newEntityIndexMap("Product", reps) @@ -24,7 +23,7 @@ func TestNewEntityIndexMap(t *testing.T) { }) t.Run("ordered representations [Product, Product, Storage, Storage]", func(t *testing.T) { - reps := getRepresentations(gjson.Parse(`{"representations":[ + reps := getRepresentations(astjson.MustParse(`{"representations":[ {"__typename":"Product","id":"1"}, {"__typename":"Product","id":"2"}, {"__typename":"Storage","id":"3"}, @@ -39,7 +38,7 @@ func TestNewEntityIndexMap(t *testing.T) { }) t.Run("unordered representations [Product, Storage, Product, Storage]", func(t *testing.T) { - reps := getRepresentations(gjson.Parse(`{"representations":[ + reps := getRepresentations(astjson.MustParse(`{"representations":[ {"__typename":"Product","id":"1"}, {"__typename":"Storage","id":"2"}, {"__typename":"Product","id":"3"}, @@ -54,7 +53,7 @@ func TestNewEntityIndexMap(t *testing.T) { }) t.Run("interleaved representations across three types", func(t *testing.T) { - reps := getRepresentations(gjson.Parse(`{"representations":[ + reps := getRepresentations(astjson.MustParse(`{"representations":[ {"__typename":"Product","id":"1"}, {"__typename":"Storage","id":"2"}, {"__typename":"Warehouse","id":"3"}, @@ -69,7 +68,7 @@ func TestNewEntityIndexMap(t *testing.T) { }) t.Run("single matching representation", func(t *testing.T) { - reps := getRepresentations(gjson.Parse(`{"representations":[ + reps := getRepresentations(astjson.MustParse(`{"representations":[ {"__typename":"Storage","id":"1"}, {"__typename":"Product","id":"2"}, {"__typename":"Storage","id":"3"} @@ -79,7 +78,7 @@ func TestNewEntityIndexMap(t *testing.T) { }) t.Run("preserves original positions for fully matching list", func(t *testing.T) { - reps := getRepresentations(gjson.Parse(`{"representations":[ + reps := getRepresentations(astjson.MustParse(`{"representations":[ {"__typename":"Product","id":"1"}, {"__typename":"Product","id":"2"}, {"__typename":"Product","id":"3"} @@ -92,7 +91,7 @@ func TestNewEntityIndexMap(t *testing.T) { // Interface-entity representations carry the interface name as __typename // (e.g. "Resource"). The index map cares only about the typename string, // not whether it refers to an interface or a concrete type. - reps := getRepresentations(gjson.Parse(`{"representations":[ + reps := getRepresentations(astjson.MustParse(`{"representations":[ {"__typename":"Resource","id":"1"}, {"__typename":"Product","id":"2"}, {"__typename":"Resource","id":"3"}, @@ -109,27 +108,27 @@ func TestNewEntityIndexMap(t *testing.T) { func TestGetRepresentations(t *testing.T) { t.Run("returns nil when representations key missing", func(t *testing.T) { - vars := gjson.Parse(`{"other":"value"}`) + vars := astjson.MustParse(`{"other":"value"}`) assert.Nil(t, getRepresentations(vars)) }) t.Run("returns empty slice when representations is empty array", func(t *testing.T) { - vars := gjson.Parse(`{"representations":[]}`) + vars := astjson.MustParse(`{"representations":[]}`) reps := getRepresentations(vars) assert.NotNil(t, reps) assert.Empty(t, reps) }) t.Run("returns representations when present", func(t *testing.T) { - vars := gjson.Parse(`{"representations":[{"__typename":"Product","id":"1"},{"__typename":"Storage","id":"2"}]}`) + vars := astjson.MustParse(`{"representations":[{"__typename":"Product","id":"1"},{"__typename":"Storage","id":"2"}]}`) reps := getRepresentations(vars) assert.Len(t, reps, 2) - assert.Equal(t, "Product", reps[0].Get("__typename").String()) - assert.Equal(t, "Storage", reps[1].Get("__typename").String()) + assert.Equal(t, "Product", string(reps[0].Get("__typename").GetStringBytes())) + assert.Equal(t, "Storage", string(reps[1].Get("__typename").GetStringBytes())) }) } func TestValidateEntityResponse(t *testing.T) { - reps := getRepresentations(gjson.Parse(`{"representations":[ + reps := getRepresentations(astjson.MustParse(`{"representations":[ {"__typename":"Product","id":"1"}, {"__typename":"Product","id":"2"} ]}`)) @@ -163,7 +162,7 @@ func TestValidateEntityResponse(t *testing.T) { }) t.Run("counts only representations of the requested type", func(t *testing.T) { - mixedReps := getRepresentations(gjson.Parse(`{"representations":[ + mixedReps := getRepresentations(astjson.MustParse(`{"representations":[ {"__typename":"Product","id":"1"}, {"__typename":"Storage","id":"2"}, {"__typename":"Product","id":"3"} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go index ed3e9d305..a5db0a3dd 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go @@ -907,6 +907,7 @@ type resolverField struct { fieldsSelectionSetRef int responsePath ast.Path contextPath ast.Path + contextJSONRoot string contextFields []contextField fieldArguments []fieldArgument @@ -1010,6 +1011,8 @@ func (r *rpcPlanningContext) setResolvedField(walker *astvisitor.Walker, fieldDe return err } + resolvedField.contextJSONRoot = string(walker.Path[1].FieldName) + for _, contextFieldRef := range contextFields { mapping := r.resolveFieldMapping( walker.EnclosingTypeDefinition.NameString(r.definition), @@ -1620,6 +1623,7 @@ func (r *rpcPlanningContext) newResolveRPCCall(config *resolveRPCCallConfig) (RP { Name: contextFieldName, ProtoTypeName: DataTypeMessage, + JSONPath: resolvedField.contextJSONRoot, Repeated: true, Message: config.contextMessage, }, diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go index 3b7ea647d..07f7ea985 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go @@ -3902,6 +3902,7 @@ func TestExecutionPlanFieldResolvers_CustomSchemas(t *testing.T) { Name: "context", ProtoTypeName: DataTypeMessage, Repeated: true, + JSONPath: "foo", Message: &RPCMessage{ Name: "ResolveFooFooResolverContext", Fields: []RPCField{ @@ -4373,13 +4374,13 @@ func TestExecutionPlan_FederationFieldResolvers(t *testing.T) { Name: "id", ProtoTypeName: DataTypeString, JSONPath: "id", - ResolvePath: buildPath("result.id"), + ResolvePath: buildPath("_entities.id"), }, { Name: "price", ProtoTypeName: DataTypeDouble, JSONPath: "price", - ResolvePath: buildPath("result.price"), + ResolvePath: buildPath("_entities.price"), }, }, }, diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 9d0d8dadf..a4986f06f 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -16,7 +16,6 @@ import ( "github.com/cespare/xxhash/v2" "github.com/tidwall/gjson" - "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/types/dynamicpb" @@ -28,10 +27,9 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafebytes" ) -type resultData struct { +type fetchData struct { kind CallKind response *astjson.Value responsePath ast.Path @@ -53,11 +51,10 @@ type DataSource struct { definition *ast.Document disabled bool - pool *arena.Pool - program *program - codecOpt grpc.CallOption - wireBuf bytes.Buffer - resultsBuf []resultData + pool *arena.Pool + program *program + codecOpt grpc.CallOption + wireBuf bytes.Buffer } type ProtoConfig struct { @@ -136,8 +133,6 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte variables := gjson.ParseBytes(input).Get("body.variables") - _ = astJsonVariables - builder := newJSONBuilder(item.Arena, d.mapping, variables) if d.disabled { @@ -159,45 +154,102 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte root := astjson.ObjectValue(nil) + callMap := make(map[int]fetchData) + + representations := getRepresentations(astJsonVariables) for _, stage := range d.program.stages { - results := d.resultsBuf[:0] - if cap(results) < len(stage.fetches) { - results = make([]resultData, 0, len(stage.fetches)) - } + results := make([]fetchData, 0, len(stage.fetches)) for _, fetch := range stage.fetches { + // TODO: unmarshal with our own codec logic responseMessage := dynamicpb.NewMessage(fetch.response.responseType.desc) - buffer, err := fetch.request.wire.createProtoWire(astJsonVariables) + requestVariables := astJsonVariables + if fetch.requestedEntityType != "" { + requestVariables = filterRepresentations(item.Arena, requestVariables, fetch.requestedEntityType) + } + + // if fetch.dependentCall != nil { + // requestVariables = astjson.DeepCopy(item.Arena, astJsonVariables) + + // call, found := callMap[fetch.dependentCall.ID] + // if !found { + // return nil, fmt.Errorf("dependent call %d not found", fetch.dependentCall.ID) + // } + + // contextField := fetch.request.rpcMessage.Fields.ByName(contextFieldName) + // if contextField == nil { + // return nil, fmt.Errorf("context field not found in dependent call %d", fetch.dependentCall.ID) + // } + + // contextValue := call.response.Get(contextField.JSONPath) + // if !contextValue.Exists() { + // return nil, fmt.Errorf("context value not found in dependent call %d", fetch.dependentCall.ID) + // } + + // var contextData []*astjson.Value + // if contextValue.Type() == astjson.TypeArray { + // contextData = contextValue.GetArray() + // } else { + // contextData = []*astjson.Value{contextValue} + // } + + // ov := astjson.ObjectValue(item.Arena) + // contextArr := astjson.ArrayValue(item.Arena) + // for i, data := range contextData { + // contextArr.SetArrayItem(item.Arena, i, data) + // } + // ov.Set(item.Arena, contextField.JSONPath, contextArr) + + // requestVariables, _, err = astjson.MergeValues(item.Arena, requestVariables, ov) + // if err != nil { + // return nil, err + // } + // } + + buffer, err := fetch.request.createProtoWire(requestVariables) if err != nil { return nil, err } err = d.cc.Invoke(ctx, fetch.methodFullName, NewPreWiredInputMessage(buffer), responseMessage) if err != nil { - return nil, err + return builder.writeErrorBytes(err), nil } responseJson, err := builder.marshalResponseJSON(&fetch.response.rpcMessage, responseMessage) if err != nil { - return nil, err + return builder.writeErrorBytes(err), nil } - results = append(results, resultData{ + fetchResult := fetchData{ kind: fetch.kind, response: responseJson, responsePath: fetch.responsePath, - }) - } + } - d.resultsBuf = results + // In case of a federated response, we need to ensure that the response is valid. + // The number of entities per type must match the number of lookup keys in the variables. + // On success we build the index map used by mergeEntities to place each response + // entity at the correct position in the final _entities array. + if fetch.kind == CallKindEntity { + if err := validateEntityResponse(responseJson, fetch.requestedEntityType, representations); err != nil { + return builder.writeErrorBytes(err), nil + } + + fetchResult.entityIndexMap = newEntityIndexMap(fetch.requestedEntityType, representations) + } + + results = append(results, fetchResult) + callMap[fetch.id] = fetchResult + } for _, result := range results { switch result.kind { case CallKindResolve, CallKindRequired: err = builder.mergeWithPath(root, result.response, result.responsePath) default: - root, err = builder.mergeValues(root, result.response) + root, err = builder.mergeValues(root, result) } if err != nil { return builder.writeErrorBytes(err), nil @@ -231,115 +283,115 @@ func (d *DataSource) LoadWithFiles(ctx context.Context, headers http.Header, inp panic("unimplemented") } -func (d *DataSource) LoadOld(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { - // get variables from input - variables := gjson.Parse(unsafebytes.BytesToString(input)).Get("body.variables") - - var poolItems []*arena.PoolItem - defer func() { - d.pool.ReleaseMany(poolItems) - }() - - item := d.acquirePoolItem(input, 0) - poolItems = append(poolItems, item) - - builder := newJSONBuilder(item.Arena, d.mapping, variables) - - if d.disabled { - return builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")), nil - } - - // convert headers to grpc metadata and attach to ctx - if len(headers) > 0 { - // assume that each header has exactly one value for default pairs size - pairs := make([]string, 0, len(headers)*2) - for headerName, headerValues := range headers { - headerName = strings.ToLower(headerName) - for _, v := range headerValues { - pairs = append(pairs, headerName, v) - } - } - ctx = metadata.AppendToOutgoingContext(ctx, pairs...) - } - - graph := NewDependencyGraph(d.plan) - - root := astjson.ObjectValue(nil) - - representations := getRepresentations(variables) - if err := graph.TopologicalSortResolve(func(nodes []FetchItem) error { - // TODO: Compile fetches should be converted to a program. - // The program defines all the fetches that need to be executed in parallel for a given query. - - serviceCalls, err := d.rc.CompileFetches(graph, nodes, variables) - if err != nil { - return err - } - - results := make([]resultData, len(serviceCalls)) - errGrp, errGrpCtx := errgroup.WithContext(ctx) - - // make gRPC calls - for index, serviceCall := range serviceCalls { - item := d.acquirePoolItem(input, index) - poolItems = append(poolItems, item) - - builder := newJSONBuilder(item.Arena, d.mapping, variables) - errGrp.Go(func() error { - // Invoke the gRPC method - this will populate serviceCall.Output - err := d.cc.Invoke(errGrpCtx, serviceCall.MethodFullName(), serviceCall.Input, serviceCall.Output) - if err != nil { - return err - } - - response, err := builder.marshalResponseJSON(&serviceCall.RPC.Response, serviceCall.Output) - if err != nil { - return err - } - - results[index] = resultData{ - kind: serviceCall.RPC.Kind, - response: response, - responsePath: serviceCall.RPC.ResponsePath, - } - - // In case of a federated response, we need to ensure that the response is valid. - // The number of entities per type must match the number of lookup keys in the variables. - // On success we build the index map used by mergeEntities to place each response - // entity at the correct position in the final _entities array. - if serviceCall.RPC.Kind == CallKindEntity { - if err := validateEntityResponse(response, serviceCall.RPC.RequestedEntityType, representations); err != nil { - return err - } - - results[index].entityIndexMap = newEntityIndexMap(serviceCall.RPC.RequestedEntityType, representations) - } - - return nil - }) - } - - if err := errGrp.Wait(); err != nil { - return err - } - - for _, result := range results { - switch result.kind { - case CallKindResolve, CallKindRequired: - err = builder.mergeWithPath(root, result.response, result.responsePath) - default: - root, err = builder.mergeValues(root, result) - } - if err != nil { - return err - } - } - - return nil - }); err != nil { - return builder.writeErrorBytes(err), nil - } - - value := builder.toDataObject(root) - return value.MarshalTo(nil), err -} +// func (d *DataSource) LoadOld(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) { +// // get variables from input +// variables := gjson.Parse(unsafebytes.BytesToString(input)).Get("body.variables") + +// var poolItems []*arena.PoolItem +// defer func() { +// d.pool.ReleaseMany(poolItems) +// }() + +// item := d.acquirePoolItem(input, 0) +// poolItems = append(poolItems, item) + +// builder := newJSONBuilder(item.Arena, d.mapping, variables) + +// if d.disabled { +// return builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")), nil +// } + +// // convert headers to grpc metadata and attach to ctx +// if len(headers) > 0 { +// // assume that each header has exactly one value for default pairs size +// pairs := make([]string, 0, len(headers)*2) +// for headerName, headerValues := range headers { +// headerName = strings.ToLower(headerName) +// for _, v := range headerValues { +// pairs = append(pairs, headerName, v) +// } +// } +// ctx = metadata.AppendToOutgoingContext(ctx, pairs...) +// } + +// graph := NewDependencyGraph(d.plan) + +// root := astjson.ObjectValue(nil) + +// representations := getRepresentations(variables) +// if err := graph.TopologicalSortResolve(func(nodes []FetchItem) error { +// // TODO: Compile fetches should be converted to a program. +// // The program defines all the fetches that need to be executed in parallel for a given query. + +// serviceCalls, err := d.rc.CompileFetches(graph, nodes, variables) +// if err != nil { +// return err +// } + +// results := make([]resultData, len(serviceCalls)) +// errGrp, errGrpCtx := errgroup.WithContext(ctx) + +// // make gRPC calls +// for index, serviceCall := range serviceCalls { +// item := d.acquirePoolItem(input, index) +// poolItems = append(poolItems, item) + +// builder := newJSONBuilder(item.Arena, d.mapping, variables) +// errGrp.Go(func() error { +// // Invoke the gRPC method - this will populate serviceCall.Output +// err := d.cc.Invoke(errGrpCtx, serviceCall.MethodFullName(), serviceCall.Input, serviceCall.Output) +// if err != nil { +// return err +// } + +// response, err := builder.marshalResponseJSON(&serviceCall.RPC.Response, serviceCall.Output) +// if err != nil { +// return err +// } + +// results[index] = resultData{ +// kind: serviceCall.RPC.Kind, +// response: response, +// responsePath: serviceCall.RPC.ResponsePath, +// } + +// // In case of a federated response, we need to ensure that the response is valid. +// // The number of entities per type must match the number of lookup keys in the variables. +// // On success we build the index map used by mergeEntities to place each response +// // entity at the correct position in the final _entities array. +// if serviceCall.RPC.Kind == CallKindEntity { +// if err := validateEntityResponse(response, serviceCall.RPC.RequestedEntityType, representations); err != nil { +// return err +// } + +// results[index].entityIndexMap = newEntityIndexMap(serviceCall.RPC.RequestedEntityType, representations) +// } + +// return nil +// }) +// } + +// if err := errGrp.Wait(); err != nil { +// return err +// } + +// for _, result := range results { +// switch result.kind { +// case CallKindResolve, CallKindRequired: +// err = builder.mergeWithPath(root, result.response, result.responsePath) +// default: +// root, err = builder.mergeValues(root, result) +// } +// if err != nil { +// return err +// } +// } + +// return nil +// }); err != nil { +// return builder.writeErrorBytes(err), nil +// } + +// value := builder.toDataObject(root) +// return value.MarshalTo(nil), err +// } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 9f70fbf4e..bedf12f59 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -377,37 +377,7 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { Definition: &schemaDoc, SubgraphName: "Products", Compiler: compiler, - Mapping: &GRPCMapping{ - Service: "Products", - QueryRPCs: RPCConfigMap[RPCConfig]{ - "complexFilterType": { - RPC: "QueryComplexFilterType", - Request: "QueryComplexFilterTypeRequest", - Response: "QueryComplexFilterTypeResponse", - }, - }, - Fields: map[string]FieldMap{ - "Query": { - "complexFilterType": { - TargetName: "complex_filter_type", - ArgumentMappings: map[string]string{ - "filter": "filter", - }, - }, - }, - "FilterType": { - "name": { - TargetName: "name", - }, - "filterField1": { - TargetName: "filter_field_1", - }, - "filterField2": { - TargetName: "filter_field_2", - }, - }, - }, - }, + Mapping: testMapping(), }) require.NoError(t, err) diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go index 556371d51..1901ee8a3 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -50,7 +50,7 @@ func newJSONBuilder(a arena.Arena, mapping *GRPCMapping, variables gjson.Result) // mergeValues combines two JSON values while preserving proper federation entity ordering. // This is a critical function for GraphQL federation where multiple subgraphs may // return entities that need to be merged in the correct order. -func (j *jsonBuilder) mergeValues(left *astjson.Value, right resultData) (*astjson.Value, error) { +func (j *jsonBuilder) mergeValues(left *astjson.Value, right fetchData) (*astjson.Value, error) { if right.kind != CallKindEntity { // No federation index map available - use simple merge // This path is taken for non-federated queries @@ -76,7 +76,7 @@ func (j *jsonBuilder) mergeValues(left *astjson.Value, right resultData) (*astjs // _entities array. On subsequent calls, left is the result of a previous // mergeEntities call and already holds the _entities array, so we mutate it // in place rather than copying every accumulated entity into a new array. -func (j *jsonBuilder) mergeEntities(left *astjson.Value, rightResult resultData) (*astjson.Value, error) { +func (j *jsonBuilder) mergeEntities(left *astjson.Value, rightResult fetchData) (*astjson.Value, error) { right := rightResult.response rightEntities := right.Get(entityPath).GetArray() diff --git a/v2/pkg/engine/datasource/grpc_datasource/program.go b/v2/pkg/engine/datasource/grpc_datasource/program.go index eac0063e9..43f2883cb 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/program.go +++ b/v2/pkg/engine/datasource/grpc_datasource/program.go @@ -3,6 +3,7 @@ package grpcdatasource import ( "fmt" + "github.com/wundergraph/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" ) @@ -11,35 +12,75 @@ type program struct { } type stage struct { - fetches []fetch + fetches []fetchProgram } -type fetch struct { - id int - kind CallKind - dependentCall *RPCCall - serviceName string - methodName string - methodFullName string - responsePath ast.Path - request *fetchRequest - response *fetchResponse +type fetchProgram struct { + id int + kind CallKind + dependentCall *RPCCall + serviceName string + methodName string + methodFullName string + responsePath ast.Path + request *request + response *response + requestedEntityType string } -type fetchRequest struct { - message *runtimeMessage - rpcMessage RPCMessage +type request struct { + message *programMessage + context *fetchRequestContext // The wire message will be created fromt the // request structure. wire *wireMessage } -type fetchResponse struct { +type programMessage struct { + name string + runtime *runtimeMessage + fields []programField +} + +type programField struct { + runtime *runtimeField + dataType DataType + jsonPath string + enumName string + staticValue string + optional bool + repeated bool + listMetadata *ListMetadata + child *programMessage +} + +type fetchRequestContext struct { + message *runtimeMessage + fields []fetchRequestContextField +} + +type fetchRequestContextField struct { + runtime *runtimeField + resolvePath resolvePath +} + +type resolvePath []*runtimeField + +type response struct { // response type is the type of the response message. responseType *runtimeMessage rpcMessage RPCMessage } +func (f *request) createProtoWire(requestVariables *astjson.Value) ([]byte, error) { + wire, err := f.wire.createProtoWire(requestVariables) + if err != nil { + return nil, err + } + + return wire, nil +} + func compileProgram(plan *RPCExecutionPlan, runtime *runtimeSchema) (*program, error) { stageIndexes, err := compileStageIndexes(plan) if err != nil { @@ -58,7 +99,7 @@ func compileProgram(plan *RPCExecutionPlan, runtime *runtimeSchema) (*program, e stages: make([]stage, stageCount), } - stageMap := make(map[int][]fetch, stageCount) + stageMap := make(map[int][]fetchProgram, stageCount) for i := range plan.Calls { call := &plan.Calls[i] @@ -86,45 +127,62 @@ func compileProgram(plan *RPCExecutionPlan, runtime *runtimeSchema) (*program, e return program, nil } -func compileFetch(call *RPCCall, runtime *runtimeSchema, dependentCall *RPCCall) (fetch, error) { +func compileFetch(call *RPCCall, runtime *runtimeSchema, dependentCall *RPCCall) (fetchProgram, error) { serviceName, ok := runtime.serviceNamesByMethod[call.MethodName] if !ok { - return fetch{}, fmt.Errorf("service name not found for method %s", call.MethodName) + return fetchProgram{}, fmt.Errorf("service name not found for method %s", call.MethodName) } - f := fetch{ - id: call.ID, - kind: call.Kind, - dependentCall: dependentCall, - serviceName: serviceName, - methodName: call.MethodName, - methodFullName: "/" + serviceName + "/" + call.MethodName, - responsePath: call.ResponsePath, + f := fetchProgram{ + id: call.ID, + kind: call.Kind, + dependentCall: dependentCall, + serviceName: serviceName, + methodName: call.MethodName, + methodFullName: "/" + serviceName + "/" + call.MethodName, + responsePath: call.ResponsePath, + requestedEntityType: call.RequestedEntityType, } requestMessage := runtime.getMessageByName(call.Request.Name) if requestMessage == nil { - return fetch{}, fmt.Errorf("request message not found for method %s", call.MethodName) + return fetchProgram{}, fmt.Errorf("request message not found for method %s", call.MethodName) } responseMessage := runtime.getMessageByName(call.Response.Name) if responseMessage == nil { - return fetch{}, fmt.Errorf("response message not found for method %s", call.MethodName) + return fetchProgram{}, fmt.Errorf("response message not found for method %s", call.MethodName) } - f.request = &fetchRequest{ - message: requestMessage, - rpcMessage: call.Request, - } - - f.response = &fetchResponse{ + f.response = &response{ responseType: responseMessage, rpcMessage: call.Response, } - wireMessage, err := compileWireMessage(runtime, &f.request.rpcMessage, requestMessage) + switch f.kind { + case CallKindStandard, CallKindEntity: + fetchRequest, err := compileFetchRequest(runtime, &call.Request, requestMessage) + if err != nil { + return fetchProgram{}, err + } + f.request = fetchRequest + + case CallKindResolve: + dependentMessage := runtime.getMessageByName(dependentCall.Response.Name) + if dependentMessage == nil { + return fetchProgram{}, fmt.Errorf("dependent message not found for method %s", dependentCall.MethodName) + } + + fetchRequest, err := compileFetchRequestWithContext(requestMessage, dependentMessage, call.Request) + if err != nil { + return fetchProgram{}, err + } + f.request = fetchRequest + } + + wireMessage, err := compileWireMessageFromRequest(runtime, f.request) if err != nil { - return fetch{}, err + return fetchProgram{}, err } f.request.wire = wireMessage @@ -132,6 +190,122 @@ func compileFetch(call *RPCCall, runtime *runtimeSchema, dependentCall *RPCCall) return f, nil } +func compileFetchRequest(runtime *runtimeSchema, rpcMessage *RPCMessage, message *runtimeMessage) (*request, error) { + requestMessage, err := compileMessage(runtime, rpcMessage, message) + if err != nil { + return nil, err + } + + return &request{ + message: requestMessage, + }, nil +} + +func compileMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, rtMessage *runtimeMessage) (*programMessage, error) { + msg := &programMessage{ + name: rpcMessage.Name, + runtime: rtMessage, + fields: make([]programField, 0, len(rpcMessage.Fields)), + } + + for _, f := range rpcMessage.Fields { + rtFieldMessage := runtime.getMessageByName(rpcMessage.Name) + if rtFieldMessage == nil { + return nil, fmt.Errorf("message not found for name %s", f.Message.Name) + } + + runtimeField := rtFieldMessage.fieldsByName[f.Name] + if runtimeField == nil { + return nil, fmt.Errorf("field not found for name %s", f.Name) + } + + requestField, err := compileField(runtime, f, runtimeField) + if err != nil { + return nil, err + } + msg.fields = append(msg.fields, requestField) + } + + return msg, nil +} + +func compileField(runtime *runtimeSchema, rpcField RPCField, rtField *runtimeField) (programField, error) { + f := programField{ + runtime: rtField, + dataType: rpcField.ProtoTypeName, + jsonPath: rpcField.JSONPath, + enumName: rpcField.EnumName, + staticValue: rpcField.StaticValue, + optional: rpcField.Optional, + repeated: rpcField.Repeated, + listMetadata: rpcField.ListMetadata, + child: nil, + } + + if rpcField.Message != nil { + if rtField.message == nil { + return programField{}, fmt.Errorf("child message not found for name %s", rpcField.Message.Name) + } + + childMessage, err := compileMessage(runtime, rpcField.Message, rtField.message) + if err != nil { + return programField{}, err + } + + f.child = childMessage + } + + return f, nil +} + +func compileFetchRequestWithContext(message *runtimeMessage, dependentMessage *runtimeMessage, rpcMessage RPCMessage) (*request, error) { + request := &request{} + + // context and field_args + for _, field := range rpcMessage.Fields { + switch field.Name { + case "context": + contextField, found := message.fieldsByName[field.Name] + if !found { + return nil, fmt.Errorf("context message not found for method %s", rpcMessage.Name) + } + + fetchRequestContext, err := compileFetchRequestContext(contextField.message, dependentMessage, field.Message) + if err != nil { + return nil, err + } + + request.context = fetchRequestContext + case "field_args": + // wireMessage, err := compileWireMessage(field.Message, message) + // if err != nil { + // return nil, err + // } + + // request.wire = wireMessage + } + } + + return request, nil +} + +func compileFetchRequestContext(message, dependentMessage *runtimeMessage, rpcMessage *RPCMessage) (*fetchRequestContext, error) { + if message == nil || dependentMessage == nil { + return nil, fmt.Errorf("unable to compile fetch request context: message or dependent message is nil") + } + + if rpcMessage == nil { + return nil, fmt.Errorf("unable to compile fetch request context: rpc message is nil") + } + + fetchRequestContext := &fetchRequestContext{ + message: message, + fields: make([]fetchRequestContextField, 0, len(rpcMessage.Fields)), + } + + return fetchRequestContext, nil +} + func compileStageIndexes(plan *RPCExecutionPlan) ([]int, error) { // We are using a slice to store the batch index for each noded ordered. stageIndexes := initializeSlice(len(plan.Calls), -1) diff --git a/v2/pkg/engine/datasource/grpc_datasource/program_test.go b/v2/pkg/engine/datasource/grpc_datasource/program_test.go index 9ddf0b022..d47570776 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/program_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/program_test.go @@ -22,9 +22,13 @@ func TestCompileProgram(t *testing.T) { expected expected err error }{ + // { + // name: "simple program", + // operation: `query UsersWithTypename { users { __typename id __typename name } }`, + // }, { - name: "simple program", - operation: `query UsersWithTypename { users { __typename id __typename name } }`, + name: "query with field resolver", + operation: `query CategoriesWithFieldResolvers($whoop: ProductCountFilter) { categories { id productCount(filters: $whoop) } }`, }, } diff --git a/v2/pkg/engine/datasource/grpc_datasource/runtime_test.go b/v2/pkg/engine/datasource/grpc_datasource/runtime_test.go index 7df3ed92f..4810c9d6a 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/runtime_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/runtime_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestNewSchemaRuntime(t *testing.T) { +func TestNewRuntimeSchema(t *testing.T) { compiler, err := NewProtoCompiler(testSchemaWithLookup, testMapping()) require.NoError(t, err) @@ -20,7 +20,7 @@ func TestNewSchemaRuntime(t *testing.T) { require.Len(t, runtime.serviceNamesByMethod, 1) } -func TestSchemaRuntimeMessages(t *testing.T) { +func TestRuntimeSchemaMessages(t *testing.T) { compiler, err := NewProtoCompiler(testSchemaWithLookup, testMapping()) require.NoError(t, err) @@ -90,7 +90,7 @@ func TestSchemaRuntimeMessages(t *testing.T) { }) } -func TestSchemaRuntimeEnums(t *testing.T) { +func TestRuntimeSchemaEnums(t *testing.T) { compiler, err := NewProtoCompiler(testSchemaWithLookup, testMapping()) require.NoError(t, err) @@ -128,7 +128,7 @@ func TestSchemaRuntimeEnums(t *testing.T) { }) } -func TestSchemaRuntimeServices(t *testing.T) { +func TestRuntimeSchemaServices(t *testing.T) { compiler, err := NewProtoCompiler(testSchemaWithLookup, testMapping()) require.NoError(t, err) diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire.go b/v2/pkg/engine/datasource/grpc_datasource/wire.go index 7f54d25fe..c9c616aad 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/wire.go +++ b/v2/pkg/engine/datasource/grpc_datasource/wire.go @@ -53,87 +53,62 @@ const ( minBufferSize = 1 << 8 // 256 bytes ) -// writeVarint writes a varint to buf using a stack-allocated scratch buffer to avoid heap allocation. -func writeVarint(buf *bytes.Buffer, v uint64) { - var scratch [10]byte - buf.Write(protowire.AppendVarint(scratch[:0], v)) -} - -// writeFixed64 writes a fixed64 to buf using a stack-allocated scratch buffer to avoid heap allocation. -func writeFixed64(buf *bytes.Buffer, v uint64) { - var scratch [8]byte - buf.Write(protowire.AppendFixed64(scratch[:0], v)) -} - -// writeLengthPrefixed writes a length-delimited field value (length varint + raw bytes) to buf -// without allocating, unlike protowire.AppendBytes(nil, data). -func writeLengthPrefixed(buf *bytes.Buffer, data []byte) { - var scratch [10]byte - buf.Write(protowire.AppendVarint(scratch[:0], uint64(len(data)))) - buf.Write(data) -} +func compileWireMessageFromRequest(schema *runtimeSchema, request *request) (*wireMessage, error) { + if request == nil { + return nil, fmt.Errorf("unable to compile wire message from request: request is nil") + } -// writeTag writes a protobuf tag to buf using a stack-allocated scratch buffer. -func writeTag(buf *bytes.Buffer, num protowire.Number, typ protowire.Type) { - var scratch [10]byte - buf.Write(protowire.AppendTag(scratch[:0], num, typ)) + return compileWireMessage(schema, request.message) } - -func compileWireMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, message *runtimeMessage) (*wireMessage, error) { - if message == nil { +func compileWireMessage(schema *runtimeSchema, msg *programMessage) (*wireMessage, error) { + if msg == nil { return nil, fmt.Errorf("message not found for fetch request") } - msg := &wireMessage{ - fields: make([]wireField, len(rpcMessage.Fields)), - } - // TODO: This is possible for `@requires` fields, but not yet supported. - if rpcMessage.OneOfType != OneOfTypeNone { - return nil, fmt.Errorf("oneof type not supported yet") - } + messageFields := msg.fields - for i := range rpcMessage.Fields { - rpcField := &rpcMessage.Fields[i] + wm := &wireMessage{ + fields: make([]wireField, len(messageFields)), + } - field, ok := message.fieldsByName[rpcField.Name] - if !ok { - return nil, fmt.Errorf("field not found for name %s", rpcField.Name) - } + for i := range messageFields { + messageField := messageFields[i] wf := wireField{ - number: field.desc.Number(), - runtimeMessage: field.message, - dataType: rpcField.ProtoTypeName, - wireType: getWireType(field.dataType), - jsonPath: rpcField.JSONPath, - staticValue: rpcField.StaticValue, - optional: rpcField.Optional, - repeated: rpcField.Repeated, - listMetadata: rpcField.ListMetadata, + number: messageField.runtime.desc.Number(), + runtimeMessage: messageField.runtime.message, + dataType: messageField.dataType, + wireType: getWireType(messageField.runtime.dataType), + jsonPath: messageField.jsonPath, + staticValue: messageField.staticValue, + optional: messageField.optional, + repeated: messageField.repeated, + listMetadata: messageField.listMetadata, } - if rpcField.EnumName != "" { - rtEnum, ok := runtime.enumByName[rpcField.EnumName] + if messageField.enumName != "" { + rtEnum, ok := schema.enumByName[messageField.enumName] if !ok { - return nil, fmt.Errorf("enum not found for name %s", rpcField.EnumName) + return nil, fmt.Errorf("enum not found for name %s", messageField.enumName) } wf.runtimeEnum = rtEnum } - if rpcField.Message != nil { - fieldMessage := field.message + if messageField.child != nil { + fieldMessageRuntime := messageField.child.runtime + // we we are using wrapper messages, they are compiled from the protobuf schema but doesn't match with the RPC planner schema. // We need to resolve the correct message from the runtime schema. - if rpcField.Message.Name != fieldMessage.name { - fieldMessage = runtime.getMessageByName(rpcField.Message.Name) - if fieldMessage == nil { - return nil, fmt.Errorf("message not found for name %s", rpcField.Message.Name) + if fieldMessageRuntime.name != messageField.child.name { + fieldMessageRuntime = schema.getMessageByName(messageField.child.runtime.name) + if fieldMessageRuntime == nil { + return nil, fmt.Errorf("message not found for name %s", messageField.child.runtime.name) } } - child, err := compileWireMessage(runtime, rpcField.Message, fieldMessage) + child, err := compileWireMessage(schema, messageField.child) if err != nil { return nil, err } @@ -142,12 +117,82 @@ func compileWireMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, message } wf.tag = protowire.AppendTag(nil, wf.number, wf.wireType) - msg.fields[i] = wf + wm.fields[i] = wf } - return msg, nil + return wm, nil + } +// func compileWireMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, message *runtimeMessage) (*wireMessage, error) { +// if message == nil { +// return nil, fmt.Errorf("message not found for fetch request") +// } +// msg := &wireMessage{ +// fields: make([]wireField, len(rpcMessage.Fields)), +// } + +// // TODO: This is possible for `@requires` fields, but not yet supported. +// if rpcMessage.OneOfType != OneOfTypeNone { +// return nil, fmt.Errorf("oneof type not supported yet") +// } + +// for i := range rpcMessage.Fields { +// rpcField := &rpcMessage.Fields[i] + +// field, ok := message.fieldsByName[rpcField.Name] +// if !ok { +// return nil, fmt.Errorf("field not found for name %s", rpcField.Name) +// } + +// wf := wireField{ +// number: field.desc.Number(), +// runtimeMessage: field.message, +// dataType: rpcField.ProtoTypeName, +// wireType: getWireType(field.dataType), +// jsonPath: rpcField.JSONPath, +// resolvePath: rpcField.ResolvePath, +// staticValue: rpcField.StaticValue, +// optional: rpcField.Optional, +// repeated: rpcField.Repeated, +// listMetadata: rpcField.ListMetadata, +// } + +// if rpcField.EnumName != "" { +// rtEnum, ok := runtime.enumByName[rpcField.EnumName] +// if !ok { +// return nil, fmt.Errorf("enum not found for name %s", rpcField.EnumName) +// } + +// wf.runtimeEnum = rtEnum +// } + +// if rpcField.Message != nil { +// fieldMessage := field.message +// // we we are using wrapper messages, they are compiled from the protobuf schema but doesn't match with the RPC planner schema. +// // We need to resolve the correct message from the runtime schema. +// if rpcField.Message.Name != fieldMessage.name { +// fieldMessage = runtime.getMessageByName(rpcField.Message.Name) +// if fieldMessage == nil { +// return nil, fmt.Errorf("message not found for name %s", rpcField.Message.Name) +// } +// } + +// child, err := compileWireMessage(runtime, rpcField.Message, fieldMessage) +// if err != nil { +// return nil, err +// } + +// wf.child = child +// } + +// wf.tag = protowire.AppendTag(nil, wf.number, wf.wireType) +// msg.fields[i] = wf +// } + +// return msg, nil +// } + // createProtoWire creates a proto wire from the wire plan. func (w *wireMessage) createProtoWire(data *astjson.Value) ([]byte, error) { buf := bytes.NewBuffer(make([]byte, 0, minBufferSize)) @@ -170,10 +215,11 @@ func (w *wireMessage) appendProtoWire(buf *bytes.Buffer, data *astjson.Value) er func (f *wireField) appendFieldWire(buf *bytes.Buffer, data *astjson.Value) error { var fieldData *astjson.Value - if f.jsonPath == "" { - fieldData = data - } else { + switch { + case f.jsonPath != "": fieldData = data.Get(f.jsonPath) + default: + fieldData = data } if !fieldData.Exists() { @@ -259,7 +305,7 @@ func (f *wireField) appendListFieldValue(buf *bytes.Buffer, data *astjson.Value, listBuffer := bytes.NewBuffer(make([]byte, 0, minBufferSize)) // We will always have a message type here, therefore we must use the bytes type. - writeTag(listBuffer, field.desc.Number(), protowire.BytesType) + listBuffer.Write(protowire.AppendTag(listBuffer.AvailableBuffer(), field.desc.Number(), protowire.BytesType)) itemsBuffer := bytes.NewBuffer(make([]byte, 0, minBufferSize)) @@ -279,10 +325,12 @@ func (f *wireField) appendListFieldValue(buf *bytes.Buffer, data *astjson.Value, } } - writeLengthPrefixed(listBuffer, itemsBuffer.Bytes()) + listBuffer.Write(protowire.AppendVarint(listBuffer.AvailableBuffer(), uint64(itemsBuffer.Len()))) + listBuffer.Write(itemsBuffer.Bytes()) buf.Write(f.tag) - writeLengthPrefixed(buf, listBuffer.Bytes()) + buf.Write(protowire.AppendVarint(buf.AvailableBuffer(), uint64(listBuffer.Len()))) + buf.Write(listBuffer.Bytes()) return nil } @@ -313,7 +361,8 @@ func (f *wireField) appendOptionalScalarFieldValue(buf *bytes.Buffer, data *astj } buf.Write(f.tag) - writeLengthPrefixed(buf, fieldBuf.Bytes()) + buf.Write(protowire.AppendVarint(buf.AvailableBuffer(), uint64(fieldBuf.Len()))) + buf.Write(fieldBuf.Bytes()) return nil } @@ -328,16 +377,19 @@ func (f *wireField) appendFieldValue(buf *bytes.Buffer, data *astjson.Value) err return err } buf.Write(f.tag) - writeLengthPrefixed(buf, childBuf.Bytes()) + buf.Write(protowire.AppendVarint(buf.AvailableBuffer(), uint64(childBuf.Len()))) + buf.Write(childBuf.Bytes()) return nil } switch f.wireType { case protowire.BytesType: buf.Write(f.tag) - writeLengthPrefixed(buf, data.GetStringBytes()) + sb := data.GetStringBytes() + buf.Write(protowire.AppendVarint(buf.AvailableBuffer(), uint64(len(sb)))) + buf.Write(sb) case protowire.VarintType: - value := data.GetUint64() + value := getUint64Value(data) if f.runtimeEnum != nil { var err error if value, err = f.getEnumValue(data); err != nil { @@ -346,10 +398,10 @@ func (f *wireField) appendFieldValue(buf *bytes.Buffer, data *astjson.Value) err } buf.Write(f.tag) - writeVarint(buf, value) + buf.Write(protowire.AppendVarint(buf.AvailableBuffer(), value)) case protowire.Fixed64Type: buf.Write(f.tag) - writeFixed64(buf, math.Float64bits(data.GetFloat64())) + buf.Write(protowire.AppendFixed64(buf.AvailableBuffer(), math.Float64bits(data.GetFloat64()))) default: return fmt.Errorf("unsupported wire type %d", f.wireType) } @@ -357,6 +409,17 @@ func (f *wireField) appendFieldValue(buf *bytes.Buffer, data *astjson.Value) err return nil } +func getUint64Value(data *astjson.Value) uint64 { + switch data.Type() { + case astjson.TypeNumber: + return data.GetUint64() + case astjson.TypeTrue: + return 1 + default: + return 0 + } +} + func (f *wireField) getEnumValue(data *astjson.Value) (uint64, error) { enumValueName := data.GetStringBytes() if len(enumValueName) == 0 { diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go index b5a61f2ae..0b4a535e1 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go @@ -151,15 +151,22 @@ func newWireTestRuntime(t *testing.T) *runtimeSchema { return runtime } -func compileTestWireMessage(t *testing.T, runtime *runtimeSchema, messageName string, rpcMessage *RPCMessage) *wireMessage { +func compileTestWireMessage(t *testing.T, runtime *runtimeSchema, message *programMessage) *wireMessage { t.Helper() - msg := runtime.getMessageByName(messageName) - require.NotNilf(t, msg, "message %q not found in runtime", messageName) - wm, err := compileWireMessage(runtime, rpcMessage, msg) + wm, err := compileWireMessage(runtime, message) require.NoError(t, err) return wm } +func compileTestProgramMessage(t *testing.T, runtime *runtimeSchema, rpcMessage *RPCMessage, rtMessage *runtimeMessage) *programMessage { + t.Helper() + + message, err := compileMessage(runtime, rpcMessage, rtMessage) + require.NoError(t, err) + + return message +} + // marshalDynamic builds a dynamicpb message using the runtime's descriptor and marshals it via proto.Marshal. // This produces the canonical protobuf encoding to compare against createProtoWire output. func marshalDynamic(t *testing.T, runtime *runtimeSchema, messageName string, build func(msg *dynamicpb.Message, desc protoref.MessageDescriptor)) []byte { @@ -203,15 +210,17 @@ func TestCompileWireMessage(t *testing.T) { runtime := newWireTestRuntime(t) t.Run("empty message with no fields", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "EmptyRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "EmptyRequest", Fields: nil, - }) + }, runtime.getMessageByName("EmptyRequest")) + + wm := compileTestWireMessage(t, runtime, message) assert.Len(t, wm.fields, 0) }) t.Run("scalar fields", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "ScalarRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "ScalarRequest", Fields: RPCFields{ {Name: "name", ProtoTypeName: DataTypeString, JSONPath: "name"}, @@ -219,7 +228,8 @@ func TestCompileWireMessage(t *testing.T) { {Name: "score", ProtoTypeName: DataTypeDouble, JSONPath: "score"}, {Name: "active", ProtoTypeName: DataTypeBool, JSONPath: "active"}, }, - }) + }, runtime.getMessageByName("ScalarRequest")) + wm := compileTestWireMessage(t, runtime, message) assert.Len(t, wm.fields, 4) assert.Equal(t, DataTypeString, wm.fields[0].dataType) assert.Equal(t, DataTypeInt32, wm.fields[1].dataType) @@ -228,7 +238,7 @@ func TestCompileWireMessage(t *testing.T) { }) t.Run("wrapper scalar fields as optional scalars", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "WrapperScalarRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "WrapperScalarRequest", Fields: RPCFields{ {Name: "name", ProtoTypeName: DataTypeString, JSONPath: "name", Optional: true}, @@ -236,7 +246,8 @@ func TestCompileWireMessage(t *testing.T) { {Name: "score", ProtoTypeName: DataTypeDouble, JSONPath: "score", Optional: true}, {Name: "active", ProtoTypeName: DataTypeBool, JSONPath: "active", Optional: true}, }, - }) + }, runtime.getMessageByName("WrapperScalarRequest")) + wm := compileTestWireMessage(t, runtime, message) assert.Len(t, wm.fields, 4) assert.True(t, wm.fields[0].optional) assert.True(t, wm.fields[1].optional) @@ -245,13 +256,14 @@ func TestCompileWireMessage(t *testing.T) { }) t.Run("repeated scalar fields", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "RepeatedScalarRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "RepeatedScalarRequest", Fields: RPCFields{ {Name: "tags", ProtoTypeName: DataTypeString, JSONPath: "tags", Repeated: true}, {Name: "scores", ProtoTypeName: DataTypeInt32, JSONPath: "scores", Repeated: true}, }, - }) + }, runtime.getMessageByName("RepeatedScalarRequest")) + wm := compileTestWireMessage(t, runtime, message) assert.Len(t, wm.fields, 2) assert.True(t, wm.fields[0].repeated) assert.True(t, wm.fields[1].repeated) @@ -260,9 +272,8 @@ func TestCompileWireMessage(t *testing.T) { t.Run("list wrapper with list metadata", func(t *testing.T) { msg := runtime.getMessageByName("ListWrapperRequest") require.NotNil(t, msg) - // Optional + IsListType: compileWireMessage must not treat this as a wrapper scalar. - // Currently this errors because it tries to wrap in google.protobuf.*Value and looks for "value" in ListOfString. - wm, err := compileWireMessage(runtime, &RPCMessage{ + + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "ListWrapperRequest", Fields: RPCFields{ { @@ -279,6 +290,10 @@ func TestCompileWireMessage(t *testing.T) { }, }, msg) + // Optional + IsListType: compileWireMessage must not treat this as a wrapper scalar. + // Currently this errors because it tries to wrap in google.protobuf.*Value and looks for "value" in ListOfString. + wm, err := compileWireMessage(runtime, message) + require.NoError(t, err) assert.Len(t, wm.fields, 1) @@ -287,7 +302,7 @@ func TestCompileWireMessage(t *testing.T) { }) t.Run("nested list wrapper with list metadata", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "NestedListRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "NestedListRequest", Fields: RPCFields{ { @@ -301,20 +316,22 @@ func TestCompileWireMessage(t *testing.T) { }, }, }, - }) + }, runtime.getMessageByName("NestedListRequest")) + wm := compileTestWireMessage(t, runtime, message) assert.Len(t, wm.fields, 1) assert.NotNil(t, wm.fields[0].listMetadata) assert.Equal(t, 2, wm.fields[0].listMetadata.NestingLevel) }) t.Run("enum field", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "EnumRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "EnumRequest", Fields: RPCFields{ {Name: "status", ProtoTypeName: DataTypeEnum, JSONPath: "status", EnumName: "Status"}, {Name: "statuses", ProtoTypeName: DataTypeEnum, JSONPath: "statuses", EnumName: "Status", Repeated: true}, }, - }) + }, runtime.getMessageByName("EnumRequest")) + wm := compileTestWireMessage(t, runtime, message) assert.Len(t, wm.fields, 2) assert.Equal(t, DataTypeEnum, wm.fields[0].dataType) assert.Equal(t, DataTypeEnum, wm.fields[1].dataType) @@ -322,7 +339,7 @@ func TestCompileWireMessage(t *testing.T) { }) t.Run("entity lookup request", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "LookupProductByIdRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "LookupProductByIdRequest", Fields: RPCFields{ { @@ -339,7 +356,8 @@ func TestCompileWireMessage(t *testing.T) { }, }, }, - }) + }, runtime.getMessageByName("LookupProductByIdRequest")) + wm := compileTestWireMessage(t, runtime, message) assert.Len(t, wm.fields, 1) assert.True(t, wm.fields[0].repeated) assert.Equal(t, DataTypeMessage, wm.fields[0].dataType) @@ -352,10 +370,11 @@ func TestCreateProtoWire(t *testing.T) { runtime := newWireTestRuntime(t) t.Run("empty message", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "EmptyRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "EmptyRequest", Fields: nil, - }) + }, runtime.getMessageByName("EmptyRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{}`)) require.NoError(t, err) @@ -366,12 +385,13 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("single string field", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "ScalarRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "ScalarRequest", Fields: RPCFields{ {Name: "name", ProtoTypeName: DataTypeString, JSONPath: "name"}, }, - }) + }, runtime.getMessageByName("ScalarRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"name":"hello"}`)) require.NoError(t, err) @@ -384,14 +404,15 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("string int32 and double fields", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "ScalarRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "ScalarRequest", Fields: RPCFields{ {Name: "name", ProtoTypeName: DataTypeString, JSONPath: "name"}, {Name: "age", ProtoTypeName: DataTypeInt32, JSONPath: "age"}, {Name: "score", ProtoTypeName: DataTypeDouble, JSONPath: "score"}, }, - }) + }, runtime.getMessageByName("ScalarRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"name":"alice","age":30,"score":99.5}`)) require.NoError(t, err) @@ -406,14 +427,15 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("wrapper string value present", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "WrapperScalarRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "WrapperScalarRequest", Fields: RPCFields{ { Name: "name", ProtoTypeName: DataTypeString, JSONPath: "name", Optional: true, }, }, - }) + }, runtime.getMessageByName("WrapperScalarRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"name":"hello"}`)) require.NoError(t, err) @@ -426,12 +448,13 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("wrapper string value absent", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "WrapperScalarRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "WrapperScalarRequest", Fields: RPCFields{ {Name: "name", ProtoTypeName: DataTypeString, JSONPath: "name", Optional: true}, }, - }) + }, runtime.getMessageByName("WrapperScalarRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{}`)) require.NoError(t, err) @@ -444,13 +467,14 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("wrapper int32 and double values", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "WrapperScalarRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "WrapperScalarRequest", Fields: RPCFields{ {Name: "age", ProtoTypeName: DataTypeInt32, JSONPath: "age", Optional: true}, {Name: "score", ProtoTypeName: DataTypeDouble, JSONPath: "score", Optional: true}, }, - }) + }, runtime.getMessageByName("WrapperScalarRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"age":25,"score":3.14}`)) require.NoError(t, err) @@ -464,12 +488,13 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("repeated strings", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "RepeatedScalarRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "RepeatedScalarRequest", Fields: RPCFields{ {Name: "tags", ProtoTypeName: DataTypeString, JSONPath: "tags", Repeated: true}, }, - }) + }, runtime.getMessageByName("RepeatedScalarRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"tags":["foo","bar","baz"]}`)) require.NoError(t, err) @@ -485,12 +510,13 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("repeated int32s", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "RepeatedScalarRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "RepeatedScalarRequest", Fields: RPCFields{ {Name: "scores", ProtoTypeName: DataTypeInt32, JSONPath: "scores", Repeated: true}, }, - }) + }, runtime.getMessageByName("RepeatedScalarRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"scores":[1,2,3]}`)) require.NoError(t, err) @@ -506,7 +532,7 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("single nested message", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "NestedMessageRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "NestedMessageRequest", Fields: RPCFields{ { @@ -520,7 +546,8 @@ func TestCreateProtoWire(t *testing.T) { }, }, }, - }) + }, runtime.getMessageByName("NestedMessageRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"item":{"id":"1","value":"a"}}`)) require.NoError(t, err) @@ -537,7 +564,7 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("repeated nested messages", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "NestedMessageRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "NestedMessageRequest", Fields: RPCFields{ { @@ -551,7 +578,8 @@ func TestCreateProtoWire(t *testing.T) { }, }, }, - }) + }, runtime.getMessageByName("NestedMessageRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"items":[{"id":"1","value":"a"},{"id":"2","value":"b"}]}`)) require.NoError(t, err) @@ -576,12 +604,13 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("enum field", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "EnumRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "EnumRequest", Fields: RPCFields{ {Name: "status", ProtoTypeName: DataTypeEnum, JSONPath: "status", EnumName: "Status"}, }, - }) + }, runtime.getMessageByName("EnumRequest")) + wm := compileTestWireMessage(t, runtime, message) // GraphQL sends enum values as strings (e.g. "ACTIVE"), not proto-prefixed names or integers. // The wire builder must resolve "ACTIVE" -> STATUS_ACTIVE = 1 via the runtime enum map. @@ -596,12 +625,13 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("repeated enums", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "EnumRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "EnumRequest", Fields: RPCFields{ {Name: "statuses", ProtoTypeName: DataTypeEnum, JSONPath: "statuses", EnumName: "Status", Repeated: true}, }, - }) + }, runtime.getMessageByName("EnumRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"statuses":["UNSPECIFIED","ACTIVE","INACTIVE"]}`)) require.NoError(t, err) @@ -619,9 +649,7 @@ func TestCreateProtoWire(t *testing.T) { t.Run("list wrapper with strings", func(t *testing.T) { // RPC plan models ListOfString as a flat optional scalar with IsListType + ListMetadata. // createProtoWire must produce: ListWrapperRequest { optional_tags: ListOfString { list: List { items: [...] } } } - msg := runtime.getMessageByName("ListWrapperRequest") - require.NotNil(t, msg) - wm, compileErr := compileWireMessage(runtime, &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "ListWrapperRequest", Fields: RPCFields{ { @@ -636,9 +664,8 @@ func TestCreateProtoWire(t *testing.T) { }, }, }, - }, msg) - - require.NoError(t, compileErr) + }, runtime.getMessageByName("ListWrapperRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"optionalTags":["a","b"]}`)) require.NoError(t, err) @@ -668,9 +695,7 @@ func TestCreateProtoWire(t *testing.T) { t.Run("nested list wrapper two levels", func(t *testing.T) { // RPC plan models ListOfListOfString as a flat scalar with IsListType + ListMetadata (NestingLevel=2). // createProtoWire must produce: NestedListRequest { tag_groups: ListOfListOfString { list: { items: [ ListOfString{...}, ... ] } } } - msg := runtime.getMessageByName("NestedListRequest") - require.NotNil(t, msg) - wm, compileErr := compileWireMessage(runtime, &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "NestedListRequest", Fields: RPCFields{ { @@ -684,9 +709,8 @@ func TestCreateProtoWire(t *testing.T) { }, }, }, - }, msg) - - require.NoError(t, compileErr) + }, runtime.getMessageByName("NestedListRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"tagGroups":[["a","b"],["c"]]}`)) require.NoError(t, err) @@ -731,9 +755,7 @@ func TestCreateProtoWire(t *testing.T) { t.Run("nested list wrapper two levels with messages", func(t *testing.T) { // NestedListRequest { item_groups: ListOfListOfNestedItem { list: { items: [ ListOfNestedItem{...}, ... ] } } } // The inner ListOfNestedItem contains NestedItem messages with id + value fields. - msg := runtime.getMessageByName("NestedListRequest") - require.NotNil(t, msg) - wm, compileErr := compileWireMessage(runtime, &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "NestedListRequest", Fields: RPCFields{ { @@ -754,9 +776,8 @@ func TestCreateProtoWire(t *testing.T) { }, }, }, - }, msg) - - require.NoError(t, compileErr) + }, runtime.getMessageByName("NestedListRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"itemGroups":[[{"id":"1","value":"a"},{"id":"2","value":"b"}],[{"id":"3","value":"c"}]]}`)) require.NoError(t, err) @@ -809,7 +830,7 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("mixed request with multiple field types", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "MixedRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "MixedRequest", Fields: RPCFields{ {Name: "id", ProtoTypeName: DataTypeString, JSONPath: "id"}, @@ -818,7 +839,8 @@ func TestCreateProtoWire(t *testing.T) { {Name: "price", ProtoTypeName: DataTypeDouble, JSONPath: "price"}, {Name: "status", ProtoTypeName: DataTypeEnum, JSONPath: "status", EnumName: "Status"}, }, - }) + }, runtime.getMessageByName("MixedRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"id":"p1","description":"a product","tags":["sale","new"],"price":29.99,"status":"ACTIVE"}`)) require.NoError(t, err) @@ -837,7 +859,7 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("entity lookup single key", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "LookupProductByIdRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "LookupProductByIdRequest", Fields: RPCFields{ { @@ -854,7 +876,8 @@ func TestCreateProtoWire(t *testing.T) { }, }, }, - }) + }, runtime.getMessageByName("LookupProductByIdRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"representations":[{"__typename":"Product","id":"1"}]}`)) require.NoError(t, err) @@ -873,7 +896,7 @@ func TestCreateProtoWire(t *testing.T) { }) t.Run("entity lookup multiple keys", func(t *testing.T) { - wm := compileTestWireMessage(t, runtime, "LookupProductByIdRequest", &RPCMessage{ + message := compileTestProgramMessage(t, runtime, &RPCMessage{ Name: "LookupProductByIdRequest", Fields: RPCFields{ { @@ -890,7 +913,8 @@ func TestCreateProtoWire(t *testing.T) { }, }, }, - }) + }, runtime.getMessageByName("LookupProductByIdRequest")) + wm := compileTestWireMessage(t, runtime, message) got, err := wm.createProtoWire(astjson.MustParse(`{"representations":[{"__typename":"Product","id":"1"},{"__typename":"Product","id":"2"},{"__typename":"Product","id":"3"}]}`)) require.NoError(t, err) From 2a21ffe36be07bcfb8c2c0c2b5c301599b99f0cc Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 24 Apr 2026 01:19:05 +0200 Subject: [PATCH 07/12] chore: improve rpc plan assertion and handle recursive compilation --- .../grpc_datasource/execution_plan.go | 4 - .../execution_plan_composite_test.go | 11 +- .../execution_plan_federation_test.go | 17 +-- .../execution_plan_field_resolvers_test.go | 37 +------ .../execution_plan_requires_test.go | 22 +--- .../grpc_datasource/execution_plan_test.go | 103 ++++++++++++++++-- .../grpc_datasource/grpc_datasource.go | 56 ++-------- .../grpc_datasource/grpc_datasource_test.go | 36 +----- .../datasource/grpc_datasource/program.go | 36 ++++-- .../engine/datasource/grpc_datasource/wire.go | 82 ++------------ .../datasource/grpc_datasource/wire_test.go | 6 +- 11 files changed, 161 insertions(+), 249 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go index a5db0a3dd..ed3e9d305 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go @@ -907,7 +907,6 @@ type resolverField struct { fieldsSelectionSetRef int responsePath ast.Path contextPath ast.Path - contextJSONRoot string contextFields []contextField fieldArguments []fieldArgument @@ -1011,8 +1010,6 @@ func (r *rpcPlanningContext) setResolvedField(walker *astvisitor.Walker, fieldDe return err } - resolvedField.contextJSONRoot = string(walker.Path[1].FieldName) - for _, contextFieldRef := range contextFields { mapping := r.resolveFieldMapping( walker.EnclosingTypeDefinition.NameString(r.definition), @@ -1623,7 +1620,6 @@ func (r *rpcPlanningContext) newResolveRPCCall(config *resolveRPCCallConfig) (RP { Name: contextFieldName, ProtoTypeName: DataTypeMessage, - JSONPath: resolvedField.contextJSONRoot, Repeated: true, Message: config.contextMessage, }, diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go index 761b735e4..8a7125ac4 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go @@ -3,7 +3,6 @@ package grpcdatasource import ( "testing" - "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" @@ -986,10 +985,7 @@ func TestCompositeTypeExecutionPlan(t *testing.T) { } require.Empty(t, tt.expectedError) - diff := cmp.Diff(tt.expectedPlan, plan) - if diff != "" { - t.Fatalf("execution plan mismatch: %s", diff) - } + assertExecutionPlanEqual(t, tt.expectedPlan, plan) }) } } @@ -1252,10 +1248,7 @@ func TestMutationUnionExecutionPlan(t *testing.T) { } require.Empty(t, tt.expectedError) - diff := cmp.Diff(tt.expectedPlan, plan) - if diff != "" { - t.Fatalf("execution plan mismatch: %s", diff) - } + assertExecutionPlanEqual(t, tt.expectedPlan, plan) }) } } diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go index a56882751..4619a67d1 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go @@ -5,8 +5,6 @@ import ( "strings" "testing" - "github.com/google/go-cmp/cmp" - "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation" @@ -296,10 +294,7 @@ func TestExecutionPlan_Federation_EntityLookup(t *testing.T) { t.Fatalf("failed to plan operation: %s", err) } - diff := cmp.Diff(tt.expectedPlan, plan) - if diff != "" { - t.Fatalf("execution plan mismatch: %s", diff) - } + assertExecutionPlanEqual(t, tt.expectedPlan, plan) }) } } @@ -1886,10 +1881,7 @@ func TestEntityLookupWithFieldResolvers_ComplexResolverInNestedMessage(t *testin t.Fatalf("failed to plan operation: %s", err) } - diff := cmp.Diff(expectedPlan, plan) - if diff != "" { - t.Fatalf("execution plan mismatch: %s", diff) - } + assertExecutionPlanEqual(t, expectedPlan, plan) }) } @@ -1936,10 +1928,7 @@ func runFederationTest(t *testing.T, tt struct { t.Fatalf("failed to plan operation: %s", err) } - diff := cmp.Diff(tt.expectedPlan, plan) - if diff != "" { - t.Fatalf("execution plan mismatch: %s", diff) - } + assertExecutionPlanEqual(t, tt.expectedPlan, plan) }) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go index 07f7ea985..b1583301a 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go @@ -3,8 +3,6 @@ package grpcdatasource import ( "testing" - "github.com/google/go-cmp/cmp" - "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation" @@ -79,7 +77,6 @@ func TestExecutionPlanFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryProductCountContext", @@ -211,7 +208,6 @@ func TestExecutionPlanFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryProductCountContext", @@ -314,7 +310,6 @@ func TestExecutionPlanFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryProductCountContext", @@ -491,7 +486,6 @@ func TestExecutionPlanFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveSubcategoryItemCountContext", @@ -623,7 +617,6 @@ func TestExecutionPlanFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryCategoryMetricsContext", @@ -747,7 +740,6 @@ func TestExecutionPlanFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryCategoryMetricsContext", @@ -874,7 +866,6 @@ func TestExecutionPlanFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryPopularityScoreContext", @@ -942,7 +933,6 @@ func TestExecutionPlanFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryCategoryMetricsContext", @@ -3902,7 +3892,6 @@ func TestExecutionPlanFieldResolvers_CustomSchemas(t *testing.T) { Name: "context", ProtoTypeName: DataTypeMessage, Repeated: true, - JSONPath: "foo", Message: &RPCMessage{ Name: "ResolveFooFooResolverContext", Fields: []RPCField{ @@ -4365,7 +4354,6 @@ func TestExecutionPlan_FederationFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveProductShippingEstimateContext", @@ -4374,13 +4362,13 @@ func TestExecutionPlan_FederationFieldResolvers(t *testing.T) { Name: "id", ProtoTypeName: DataTypeString, JSONPath: "id", - ResolvePath: buildPath("_entities.id"), + ResolvePath: buildPath("result.id"), }, { Name: "price", ProtoTypeName: DataTypeDouble, JSONPath: "price", - ResolvePath: buildPath("_entities.price"), + ResolvePath: buildPath("result.price"), }, }, }, @@ -4609,7 +4597,6 @@ func TestExecutionPlan_FederationFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveProductShippingEstimateContext", @@ -4716,10 +4703,7 @@ func TestExecutionPlan_FederationFieldResolvers(t *testing.T) { t.Fatalf("failed to plan operation: %s", err) } - diff := cmp.Diff(tt.expectedPlan, plan) - if diff != "" { - t.Fatalf("execution plan mismatch: %s", diff) - } + assertExecutionPlanEqual(t, tt.expectedPlan, plan) }) } } @@ -4818,7 +4802,6 @@ func TestExecutionPlan_FederationFieldResolvers_WithCompositeTypes(t *testing.T) { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveProductMascotRecommendationContext", @@ -4997,7 +4980,6 @@ func TestExecutionPlan_FederationFieldResolvers_WithCompositeTypes(t *testing.T) { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveProductStockStatusContext", @@ -5186,7 +5168,6 @@ func TestExecutionPlan_FederationFieldResolvers_WithCompositeTypes(t *testing.T) { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveProductProductDetailsContext", @@ -5365,10 +5346,7 @@ func TestExecutionPlan_FederationFieldResolvers_WithCompositeTypes(t *testing.T) t.Fatalf("failed to plan operation: %s", err) } - diff := cmp.Diff(tt.expectedPlan, plan) - if diff != "" { - t.Fatalf("execution plan mismatch: %s", diff) - } + assertExecutionPlanEqual(t, tt.expectedPlan, plan) }) } } @@ -5643,7 +5621,6 @@ func TestExecutionPlanFieldResolvers_ArgumentLess(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryTotalProductsContext", @@ -5735,7 +5712,6 @@ func TestExecutionPlanFieldResolvers_ArgumentLess(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryTopSubcategoryContext", @@ -5844,7 +5820,6 @@ func TestExecutionPlanFieldResolvers_ArgumentLess(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryActiveSubcategoriesContext", @@ -5947,7 +5922,6 @@ func TestExecutionPlanFieldResolvers_ArgumentLess(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryCategoryMetricsContext", @@ -6030,7 +6004,6 @@ func TestExecutionPlanFieldResolvers_ArgumentLess(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryMetricsAverageScoreContext", @@ -6123,7 +6096,6 @@ func TestExecutionPlanFieldResolvers_ArgumentLess(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryTotalProductsContext", @@ -6174,7 +6146,6 @@ func TestExecutionPlanFieldResolvers_ArgumentLess(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveCategoryProductCountContext", diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_requires_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_requires_test.go index d70a2c5e9..b9f248bf5 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_requires_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_requires_test.go @@ -3,8 +3,6 @@ package grpcdatasource import ( "testing" - "github.com/google/go-cmp/cmp" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" @@ -2223,10 +2221,7 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { t.Fatalf("failed to plan operation: %s", err) } - diff := cmp.Diff(tt.expectedPlan, plan) - if diff != "" { - t.Fatalf("execution plan mismatch: %s", diff) - } + assertExecutionPlanEqual(t, tt.expectedPlan, plan) }) } } @@ -3143,10 +3138,7 @@ func TestExecutionPlan_FederationRequires_AbstractTypes(t *testing.T) { t.Fatalf("failed to plan operation: %s", err) } - diff := cmp.Diff(tt.expectedPlan, plan) - if diff != "" { - t.Fatalf("execution plan mismatch: %s", diff) - } + assertExecutionPlanEqual(t, tt.expectedPlan, plan) }) } } @@ -3246,7 +3238,6 @@ func TestExecutionPlan_FederationRequires_WithFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveStorageStorageStatusContext", @@ -3481,7 +3472,6 @@ func TestExecutionPlan_FederationRequires_WithFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveStorageLinkedStoragesContext", @@ -3738,7 +3728,6 @@ func TestExecutionPlan_FederationRequires_WithFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveStorageNearbyStoragesContext", @@ -3992,7 +3981,6 @@ func TestExecutionPlan_FederationRequires_WithFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveStorageStorageStatusContext", @@ -4315,7 +4303,6 @@ func TestExecutionPlan_FederationRequires_WithFieldResolvers(t *testing.T) { { Name: "context", ProtoTypeName: DataTypeMessage, - JSONPath: "", Repeated: true, Message: &RPCMessage{ Name: "ResolveStorageLinkedStoragesContext", @@ -4535,10 +4522,7 @@ func TestExecutionPlan_FederationRequires_WithFieldResolvers(t *testing.T) { t.Fatalf("failed to plan operation: %s", err) } - diff := cmp.Diff(tt.expectedPlan, plan) - if diff != "" { - t.Fatalf("execution plan mismatch: %s", diff) - } + assertExecutionPlanEqual(t, tt.expectedPlan, plan) }) } } diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go index 77b60b813..65dd851de 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go @@ -1,9 +1,11 @@ package grpcdatasource import ( + "fmt" + "slices" "testing" - "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" @@ -28,6 +30,7 @@ type testConfig struct { } func runTestWithConfig(t *testing.T, testCase testCase, testConfig testConfig) { + t.Helper() rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{ subgraphName: testConfig.subgraphName, mapping: testConfig.mapping, @@ -42,13 +45,11 @@ func runTestWithConfig(t *testing.T, testCase testCase, testConfig testConfig) { } require.Empty(t, testCase.expectedError) - diff := cmp.Diff(testCase.expectedPlan, plan) - if diff != "" { - t.Fatalf("execution plan mismatch: %s", diff) - } + assertExecutionPlanEqual(t, testCase.expectedPlan, plan) } func runTest(t *testing.T, testCase testCase) { + t.Helper() // Parse the GraphQL schema schemaDoc := grpctest.MustGraphQLSchema(t) @@ -1148,10 +1149,7 @@ func TestExecutionPlan_Query(t *testing.T) { } require.Empty(t, tt.expectedError) - diff := cmp.Diff(tt.expectedPlan, plan) - if diff != "" { - t.Fatalf("execution plan mismatch: %s", diff) - } + assertExecutionPlanEqual(t, tt.expectedPlan, plan) }) } } @@ -2580,6 +2578,93 @@ func TestExecutionPlan_Operations_WithAliases(t *testing.T) { } +func assertExecutionPlanEqual(t *testing.T, expected, got *RPCExecutionPlan) { + t.Helper() + require.Equal(t, len(expected.Calls), len(got.Calls), "call count mismatch: expected %d, got %d", len(expected.Calls), len(got.Calls)) + + gotByID := make(map[int]RPCCall, len(got.Calls)) + for _, c := range got.Calls { + gotByID[c.ID] = c + } + + for _, expectedCall := range expected.Calls { + gotCall, ok := gotByID[expectedCall.ID] + require.True(t, ok, "missing call with ID %d (method %s)", expectedCall.ID, expectedCall.MethodName) + assertCallEqual(t, expectedCall, gotCall, fmt.Sprintf("call[ID=%d]", expectedCall.ID)) + } +} + +func assertCallEqual(t *testing.T, expected, got RPCCall, path string) { + t.Helper() + assert.Equal(t, expected.ID, got.ID, "%s.ID", path) + assert.Equal(t, expected.Kind, got.Kind, "%s.Kind", path) + assert.Equal(t, expected.ServiceName, got.ServiceName, "%s.ServiceName", path) + assert.Equal(t, expected.MethodName, got.MethodName, "%s.MethodName", path) + assert.Equal(t, expected.RequestedEntityType, got.RequestedEntityType, "%s.RequestedEntityType", path) + assert.Equal(t, expected.ResponsePath, got.ResponsePath, "%s.ResponsePath", path) + + expectedDeps := slices.Clone(expected.DependentCalls) + gotDeps := slices.Clone(got.DependentCalls) + slices.Sort(expectedDeps) + slices.Sort(gotDeps) + assert.Equal(t, expectedDeps, gotDeps, "%s.DependentCalls", path) + + assertMessageEqual(t, expected.Request, got.Request, path+".Request") + assertMessageEqual(t, expected.Response, got.Response, path+".Response") +} + +func assertMessageEqual(t *testing.T, expected, got RPCMessage, path string) { + t.Helper() + assert.Equal(t, expected.Name, got.Name, "%s.Name", path) + assert.Equal(t, expected.OneOfType, got.OneOfType, "%s.OneOfType", path) + assert.Equal(t, expected.MemberTypes, got.MemberTypes, "%s.MemberTypes", path) + + assertFieldsEqual(t, expected.Fields, got.Fields, path+".Fields") + + // FragmentFields + if expected.FragmentFields == nil && got.FragmentFields == nil { + return + } + require.Equal(t, len(expected.FragmentFields), len(got.FragmentFields), "%s.FragmentFields length mismatch", path) + for typeName, expectedFields := range expected.FragmentFields { + gotFields, ok := got.FragmentFields[typeName] + require.True(t, ok, "%s.FragmentFields missing type %q", path, typeName) + assertFieldsEqual(t, expectedFields, gotFields, fmt.Sprintf("%s.FragmentFields[%s]", path, typeName)) + } +} + +func assertFieldsEqual(t *testing.T, expected, got RPCFields, path string) { + t.Helper() + require.Equal(t, len(expected), len(got), "%s length mismatch: expected %d, got %d", path, len(expected), len(got)) + + for i := range expected { + fp := fmt.Sprintf("%s[%d]", path, i) + ef := expected[i] + gf := got[i] + + assert.Equal(t, ef.Name, gf.Name, "%s.Name", fp) + assert.Equal(t, ef.Alias, gf.Alias, "%s.Alias", fp) + assert.Equal(t, ef.ProtoTypeName, gf.ProtoTypeName, "%s.ProtoTypeName", fp) + assert.Equal(t, ef.JSONPath, gf.JSONPath, "%s.JSONPath", fp) + assert.Equal(t, ef.EnumName, gf.EnumName, "%s.EnumName", fp) + assert.Equal(t, ef.StaticValue, gf.StaticValue, "%s.StaticValue", fp) + assert.Equal(t, ef.Repeated, gf.Repeated, "%s.Repeated", fp) + assert.Equal(t, ef.Optional, gf.Optional, "%s.Optional", fp) + assert.Equal(t, ef.IsListType, gf.IsListType, "%s.IsListType", fp) + assert.Equal(t, ef.ResolvePath, gf.ResolvePath, "%s.ResolvePath", fp) + assert.Equal(t, ef.ListMetadata, gf.ListMetadata, "%s.ListMetadata", fp) + + switch { + case ef.Message == nil && gf.Message == nil: + // ok + case ef.Message == nil || gf.Message == nil: + t.Errorf("%s.Message: expected nil=%v, got nil=%v", fp, ef.Message == nil, gf.Message == nil) + default: + assertMessageEqual(t, *ef.Message, *gf.Message, fp+".Message") + } + } +} + func testSchema(t *testing.T, schema string) ast.Document { t.Helper() diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index a4986f06f..8aa1e7ef2 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -30,10 +30,11 @@ import ( ) type fetchData struct { - kind CallKind - response *astjson.Value - responsePath ast.Path - entityIndexMap entityIndexMap + kind CallKind + responseMessage *dynamicpb.Message + response *astjson.Value + responsePath ast.Path + entityIndexMap entityIndexMap } // Verify DataSource implements the resolve.DataSource interface @@ -116,6 +117,8 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte item := d.acquirePoolItem(input, 0) poolItems = append(poolItems, item) + + fmt.Println("input", string(input)) // get variables from input value, err := astjson.ParseBytesWithArena(item.Arena, input) if err != nil { @@ -169,44 +172,6 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte requestVariables = filterRepresentations(item.Arena, requestVariables, fetch.requestedEntityType) } - // if fetch.dependentCall != nil { - // requestVariables = astjson.DeepCopy(item.Arena, astJsonVariables) - - // call, found := callMap[fetch.dependentCall.ID] - // if !found { - // return nil, fmt.Errorf("dependent call %d not found", fetch.dependentCall.ID) - // } - - // contextField := fetch.request.rpcMessage.Fields.ByName(contextFieldName) - // if contextField == nil { - // return nil, fmt.Errorf("context field not found in dependent call %d", fetch.dependentCall.ID) - // } - - // contextValue := call.response.Get(contextField.JSONPath) - // if !contextValue.Exists() { - // return nil, fmt.Errorf("context value not found in dependent call %d", fetch.dependentCall.ID) - // } - - // var contextData []*astjson.Value - // if contextValue.Type() == astjson.TypeArray { - // contextData = contextValue.GetArray() - // } else { - // contextData = []*astjson.Value{contextValue} - // } - - // ov := astjson.ObjectValue(item.Arena) - // contextArr := astjson.ArrayValue(item.Arena) - // for i, data := range contextData { - // contextArr.SetArrayItem(item.Arena, i, data) - // } - // ov.Set(item.Arena, contextField.JSONPath, contextArr) - - // requestVariables, _, err = astjson.MergeValues(item.Arena, requestVariables, ov) - // if err != nil { - // return nil, err - // } - // } - buffer, err := fetch.request.createProtoWire(requestVariables) if err != nil { return nil, err @@ -223,9 +188,10 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte } fetchResult := fetchData{ - kind: fetch.kind, - response: responseJson, - responsePath: fetch.responsePath, + kind: fetch.kind, + response: responseJson, + responseMessage: responseMessage, + responsePath: fetch.responsePath, } // In case of a federated response, we need to ensure that the response is valid. diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index bedf12f59..aee42c7fc 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -236,37 +236,7 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { Definition: &schemaDoc, SubgraphName: "Products", Compiler: compiler, - Mapping: &GRPCMapping{ - Service: "Products", - QueryRPCs: RPCConfigMap[RPCConfig]{ - "complexFilterType": { - RPC: "QueryComplexFilterType", - Request: "QueryComplexFilterTypeRequest", - Response: "QueryComplexFilterTypeResponse", - }, - }, - Fields: map[string]FieldMap{ - "Query": { - "complexFilterType": { - TargetName: "complex_filter_type", - ArgumentMappings: map[string]string{ - "filter": "filter", - }, - }, - }, - "FilterType": { - "name": { - TargetName: "name", - }, - "filterField1": { - TargetName: "filter_field_1", - }, - "filterField2": { - TargetName: "filter_field_2", - }, - }, - }, - }, + Mapping: testMapping(), }) require.NoError(t, err) @@ -2630,7 +2600,7 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { collaborations } }`, - vars: `{"variables":{"input":{"name":"New Author","email":"author@example.com","skills":["Go","GraphQL","gRPC"],"languages":["English","Spanish"],"socialLinks":["twitter.com/author","github.com/author"],"teamsByProject":[["Alice","Bob"],["Charlie","David","Eve"]],"collaborations":[["Project1","Project2"],["Project3"]]}}}`, + vars: `{"variables":{"input":{"name":"New Author","email":"author@example.com","skills":["Go","GraphQL","gRPC"],"languages":["English","Spanish"],"socialLinks":["twitter.com/author","github.com/author"],"teamsByProject":[["Alice","Bob"],["Charlie","David","Eve"]],"collaborations":[["Project1","Project2"],["Project3"]],"favoriteCategories":[]}}}`, validate: func(t *testing.T, data map[string]interface{}) { createAuthor, ok := data["createAuthor"].(map[string]interface{}) require.True(t, ok, "createAuthor should be an object") @@ -3441,7 +3411,7 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { "viewCounts":[300,400,500], "tagGroups":[["updated","tags"],["bulk","update"]], "commentThreads":[["Updated comment"]], - "relatedTopics":[["updated","topics"]], + "relatedTopics":[["updated","topics"]] } ]}}`, validate: func(t *testing.T, data map[string]interface{}) { diff --git a/v2/pkg/engine/datasource/grpc_datasource/program.go b/v2/pkg/engine/datasource/grpc_datasource/program.go index 43f2883cb..9b6cd7208 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/program.go +++ b/v2/pkg/engine/datasource/grpc_datasource/program.go @@ -5,6 +5,7 @@ import ( "github.com/wundergraph/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + protoref "google.golang.org/protobuf/reflect/protoreflect" ) type program struct { @@ -30,6 +31,7 @@ type fetchProgram struct { type request struct { message *programMessage + fields []programField context *fetchRequestContext // The wire message will be created fromt the // request structure. @@ -81,6 +83,11 @@ func (f *request) createProtoWire(requestVariables *astjson.Value) ([]byte, erro return wire, nil } +// TODO: Implement this +func (f *request) createProtoWireWithContext(requestVariables *astjson.Value, contextMessage protoref.Message) ([]byte, error) { + return nil, nil +} + func compileProgram(plan *RPCExecutionPlan, runtime *runtimeSchema) (*program, error) { stageIndexes, err := compileStageIndexes(plan) if err != nil { @@ -173,7 +180,7 @@ func compileFetch(call *RPCCall, runtime *runtimeSchema, dependentCall *RPCCall) return fetchProgram{}, fmt.Errorf("dependent message not found for method %s", dependentCall.MethodName) } - fetchRequest, err := compileFetchRequestWithContext(requestMessage, dependentMessage, call.Request) + fetchRequest, err := compileFetchRequestWithContext(runtime, requestMessage, dependentMessage, &call.Request) if err != nil { return fetchProgram{}, err } @@ -191,23 +198,30 @@ func compileFetch(call *RPCCall, runtime *runtimeSchema, dependentCall *RPCCall) } func compileFetchRequest(runtime *runtimeSchema, rpcMessage *RPCMessage, message *runtimeMessage) (*request, error) { - requestMessage, err := compileMessage(runtime, rpcMessage, message) + requestMessage, err := compileMessage(runtime, rpcMessage, message, make(map[string]*programMessage)) if err != nil { return nil, err } return &request{ message: requestMessage, + fields: requestMessage.fields, }, nil } -func compileMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, rtMessage *runtimeMessage) (*programMessage, error) { +func compileMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, rtMessage *runtimeMessage, cycleMap map[string]*programMessage) (*programMessage, error) { + if seen, ok := cycleMap[rpcMessage.Name]; ok { + return seen, nil + } + msg := &programMessage{ name: rpcMessage.Name, runtime: rtMessage, fields: make([]programField, 0, len(rpcMessage.Fields)), } + cycleMap[rpcMessage.Name] = msg + for _, f := range rpcMessage.Fields { rtFieldMessage := runtime.getMessageByName(rpcMessage.Name) if rtFieldMessage == nil { @@ -219,7 +233,7 @@ func compileMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, rtMessage *r return nil, fmt.Errorf("field not found for name %s", f.Name) } - requestField, err := compileField(runtime, f, runtimeField) + requestField, err := compileField(runtime, f, runtimeField, cycleMap) if err != nil { return nil, err } @@ -229,7 +243,7 @@ func compileMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, rtMessage *r return msg, nil } -func compileField(runtime *runtimeSchema, rpcField RPCField, rtField *runtimeField) (programField, error) { +func compileField(runtime *runtimeSchema, rpcField RPCField, rtField *runtimeField, cycleMap map[string]*programMessage) (programField, error) { f := programField{ runtime: rtField, dataType: rpcField.ProtoTypeName, @@ -247,7 +261,7 @@ func compileField(runtime *runtimeSchema, rpcField RPCField, rtField *runtimeFie return programField{}, fmt.Errorf("child message not found for name %s", rpcField.Message.Name) } - childMessage, err := compileMessage(runtime, rpcField.Message, rtField.message) + childMessage, err := compileMessage(runtime, rpcField.Message, rtField.message, cycleMap) if err != nil { return programField{}, err } @@ -258,9 +272,17 @@ func compileField(runtime *runtimeSchema, rpcField RPCField, rtField *runtimeFie return f, nil } -func compileFetchRequestWithContext(message *runtimeMessage, dependentMessage *runtimeMessage, rpcMessage RPCMessage) (*request, error) { +func compileFetchRequestWithContext(runtime *runtimeSchema, message *runtimeMessage, dependentMessage *runtimeMessage, rpcMessage *RPCMessage) (*request, error) { request := &request{} + requestMessage, err := compileMessage(runtime, rpcMessage, message, make(map[string]*programMessage)) + if err != nil { + return nil, err + } + + request.message = requestMessage + request.fields = requestMessage.fields + // context and field_args for _, field := range rpcMessage.Fields { switch field.Name { diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire.go b/v2/pkg/engine/datasource/grpc_datasource/wire.go index c9c616aad..30e1d6f0e 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/wire.go +++ b/v2/pkg/engine/datasource/grpc_datasource/wire.go @@ -58,10 +58,14 @@ func compileWireMessageFromRequest(schema *runtimeSchema, request *request) (*wi return nil, fmt.Errorf("unable to compile wire message from request: request is nil") } - return compileWireMessage(schema, request.message) + return compileWireMessage(schema, request.message, make(map[string]*wireMessage)) } -func compileWireMessage(schema *runtimeSchema, msg *programMessage) (*wireMessage, error) { +func compileWireMessage(schema *runtimeSchema, msg *programMessage, cycleMap map[string]*wireMessage) (*wireMessage, error) { + if seen, ok := cycleMap[msg.name]; ok { + return seen, nil + } + if msg == nil { return nil, fmt.Errorf("message not found for fetch request") } @@ -72,6 +76,8 @@ func compileWireMessage(schema *runtimeSchema, msg *programMessage) (*wireMessag fields: make([]wireField, len(messageFields)), } + cycleMap[msg.name] = wm + for i := range messageFields { messageField := messageFields[i] @@ -108,7 +114,7 @@ func compileWireMessage(schema *runtimeSchema, msg *programMessage) (*wireMessag } } - child, err := compileWireMessage(schema, messageField.child) + child, err := compileWireMessage(schema, messageField.child, cycleMap) if err != nil { return nil, err } @@ -121,78 +127,8 @@ func compileWireMessage(schema *runtimeSchema, msg *programMessage) (*wireMessag } return wm, nil - } -// func compileWireMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, message *runtimeMessage) (*wireMessage, error) { -// if message == nil { -// return nil, fmt.Errorf("message not found for fetch request") -// } -// msg := &wireMessage{ -// fields: make([]wireField, len(rpcMessage.Fields)), -// } - -// // TODO: This is possible for `@requires` fields, but not yet supported. -// if rpcMessage.OneOfType != OneOfTypeNone { -// return nil, fmt.Errorf("oneof type not supported yet") -// } - -// for i := range rpcMessage.Fields { -// rpcField := &rpcMessage.Fields[i] - -// field, ok := message.fieldsByName[rpcField.Name] -// if !ok { -// return nil, fmt.Errorf("field not found for name %s", rpcField.Name) -// } - -// wf := wireField{ -// number: field.desc.Number(), -// runtimeMessage: field.message, -// dataType: rpcField.ProtoTypeName, -// wireType: getWireType(field.dataType), -// jsonPath: rpcField.JSONPath, -// resolvePath: rpcField.ResolvePath, -// staticValue: rpcField.StaticValue, -// optional: rpcField.Optional, -// repeated: rpcField.Repeated, -// listMetadata: rpcField.ListMetadata, -// } - -// if rpcField.EnumName != "" { -// rtEnum, ok := runtime.enumByName[rpcField.EnumName] -// if !ok { -// return nil, fmt.Errorf("enum not found for name %s", rpcField.EnumName) -// } - -// wf.runtimeEnum = rtEnum -// } - -// if rpcField.Message != nil { -// fieldMessage := field.message -// // we we are using wrapper messages, they are compiled from the protobuf schema but doesn't match with the RPC planner schema. -// // We need to resolve the correct message from the runtime schema. -// if rpcField.Message.Name != fieldMessage.name { -// fieldMessage = runtime.getMessageByName(rpcField.Message.Name) -// if fieldMessage == nil { -// return nil, fmt.Errorf("message not found for name %s", rpcField.Message.Name) -// } -// } - -// child, err := compileWireMessage(runtime, rpcField.Message, fieldMessage) -// if err != nil { -// return nil, err -// } - -// wf.child = child -// } - -// wf.tag = protowire.AppendTag(nil, wf.number, wf.wireType) -// msg.fields[i] = wf -// } - -// return msg, nil -// } - // createProtoWire creates a proto wire from the wire plan. func (w *wireMessage) createProtoWire(data *astjson.Value) ([]byte, error) { buf := bytes.NewBuffer(make([]byte, 0, minBufferSize)) diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go index 0b4a535e1..e9623f84f 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go @@ -153,7 +153,7 @@ func newWireTestRuntime(t *testing.T) *runtimeSchema { func compileTestWireMessage(t *testing.T, runtime *runtimeSchema, message *programMessage) *wireMessage { t.Helper() - wm, err := compileWireMessage(runtime, message) + wm, err := compileWireMessage(runtime, message, make(map[string]*wireMessage)) require.NoError(t, err) return wm } @@ -161,7 +161,7 @@ func compileTestWireMessage(t *testing.T, runtime *runtimeSchema, message *progr func compileTestProgramMessage(t *testing.T, runtime *runtimeSchema, rpcMessage *RPCMessage, rtMessage *runtimeMessage) *programMessage { t.Helper() - message, err := compileMessage(runtime, rpcMessage, rtMessage) + message, err := compileMessage(runtime, rpcMessage, rtMessage, make(map[string]*programMessage)) require.NoError(t, err) return message @@ -292,7 +292,7 @@ func TestCompileWireMessage(t *testing.T) { // Optional + IsListType: compileWireMessage must not treat this as a wrapper scalar. // Currently this errors because it tries to wrap in google.protobuf.*Value and looks for "value" in ListOfString. - wm, err := compileWireMessage(runtime, message) + wm, err := compileWireMessage(runtime, message, make(map[string]*wireMessage)) require.NoError(t, err) From f88a09cad7b972fdaa427e12abd63b34da3a2ed9 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 24 Apr 2026 15:53:00 +0200 Subject: [PATCH 08/12] feat: support oneof types in the wire for requires fields --- .../datasource/grpc_datasource/compiler.go | 42 -- .../grpc_datasource/grpc_datasource.go | 72 ++- .../datasource/grpc_datasource/program.go | 110 +++- .../datasource/grpc_datasource/runtime.go | 9 + .../engine/datasource/grpc_datasource/wire.go | 500 ++++++++++++++++-- 5 files changed, 612 insertions(+), 121 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/compiler.go b/v2/pkg/engine/datasource/grpc_datasource/compiler.go index 08b974e07..40a25977f 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/compiler.go +++ b/v2/pkg/engine/datasource/grpc_datasource/compiler.go @@ -377,48 +377,6 @@ func (s *ServiceCall) MethodFullName() string { return builder.String() } -// func (p *RPCCompiler) CompileFetches(graph *DependencyGraph, fetches []FetchItem, inputData gjson.Result) ([]Invocation, error) { -// invocations := make([]Invocation, 0, len(fetches)) - -// resultChan := make(chan Invocation, len(fetches)) -// errChan := make(chan error, len(fetches)) - -// wg := sync.WaitGroup{} -// wg.Add(len(fetches)) - -// for _, node := range fetches { -// go func() { -// defer wg.Done() -// invocation, err := p.CompileNode(graph, node, inputData) -// if err != nil { -// errChan <- err -// return -// } - -// resultChan <- invocation -// node.Invocation = &invocation -// }() -// } - -// close(resultChan) -// close(errChan) - -// var joinErr error -// for err := range errChan { -// joinErr = errors.Join(joinErr, err) -// } - -// if joinErr != nil { -// return nil, joinErr -// } - -// for invocation := range resultChan { -// invocations = append(invocations, invocation) -// } - -// return invocations, nil -// } - func (p *RPCCompiler) CompileFetches(graph *DependencyGraph, fetches []FetchItem, inputData gjson.Result) ([]ServiceCall, error) { serviceCalls := make([]ServiceCall, 0, len(fetches)) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 8aa1e7ef2..b9cf3e761 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -35,6 +35,7 @@ type fetchData struct { response *astjson.Value responsePath ast.Path entityIndexMap entityIndexMap + skipped bool } // Verify DataSource implements the resolve.DataSource interface @@ -118,7 +119,6 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte item := d.acquirePoolItem(input, 0) poolItems = append(poolItems, item) - fmt.Println("input", string(input)) // get variables from input value, err := astjson.ParseBytesWithArena(item.Arena, input) if err != nil { @@ -164,19 +164,16 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte results := make([]fetchData, 0, len(stage.fetches)) for _, fetch := range stage.fetches { - // TODO: unmarshal with our own codec logic - responseMessage := dynamicpb.NewMessage(fetch.response.responseType.desc) - - requestVariables := astJsonVariables - if fetch.requestedEntityType != "" { - requestVariables = filterRepresentations(item.Arena, requestVariables, fetch.requestedEntityType) - } - - buffer, err := fetch.request.createProtoWire(requestVariables) + buffer, skip, err := createProtoWire(item.Arena, &fetch, callMap, astJsonVariables) if err != nil { return nil, err } + if skip { + continue + } + + responseMessage := dynamicpb.NewMessage(fetch.response.responseType.desc) err = d.cc.Invoke(ctx, fetch.methodFullName, NewPreWiredInputMessage(buffer), responseMessage) if err != nil { return builder.writeErrorBytes(err), nil @@ -238,6 +235,61 @@ func (d *DataSource) acquirePoolItem(input []byte, index int) *arena.PoolItem { return item } +func createProtoWire(a arena.Arena, fetch *fetchProgram, callMap map[int]fetchData, requestVariables *astjson.Value) ([]byte, bool, error) { + var buffer []byte + var err error + + switch fetch.kind { + case CallKindStandard: + buffer, err = fetch.request.createProtoWire(requestVariables) + if err != nil { + return nil, false, err + } + case CallKindEntity, CallKindRequired: + if fetch.requestedEntityType != "" { + requestVariables = filterRepresentations(a, requestVariables, fetch.requestedEntityType) + } + + buffer, err = fetch.request.createProtoWire(requestVariables) + if err != nil { + return nil, false, err + } + case CallKindResolve: + contextFetch, found := callMap[fetch.dependentCall.ID] + if !found { + return nil, false, fmt.Errorf("context fetch not found for dependent call %d", fetch.dependentCall.ID) + } + + if contextFetch.responseMessage == nil || contextFetch.skipped { + fetchResult := fetchData{ + kind: fetch.kind, + responsePath: fetch.responsePath, + skipped: true, + } + + callMap[fetch.id] = fetchResult + } + + buffer, err = fetch.request.createProtoWireWithContext(a, requestVariables, contextFetch.responseMessage) + if err != nil { + if err == errShouldSkip { + fetchResult := fetchData{ + kind: fetch.kind, + responsePath: fetch.responsePath, + skipped: true, + } + + callMap[fetch.id] = fetchResult + return nil, true, nil + } + + return nil, false, err + } + } + + return buffer, false, nil +} + // LoadWithFiles implements resolve.DataSource interface. // Similar to Load, but handles file uploads if needed. // diff --git a/v2/pkg/engine/datasource/grpc_datasource/program.go b/v2/pkg/engine/datasource/grpc_datasource/program.go index 9b6cd7208..d2080c566 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/program.go +++ b/v2/pkg/engine/datasource/grpc_datasource/program.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" protoref "google.golang.org/protobuf/reflect/protoreflect" ) @@ -39,9 +40,12 @@ type request struct { } type programMessage struct { - name string - runtime *runtimeMessage - fields []programField + name string + runtime *runtimeMessage + oneOfType OneOfType + oneOfFields map[string][]programField + memberTypes []string + fields []programField } type programField struct { @@ -58,11 +62,14 @@ type programField struct { type fetchRequestContext struct { message *runtimeMessage + context *runtimeMessage fields []fetchRequestContextField } type fetchRequestContextField struct { runtime *runtimeField + jsonName string + p ast.Path resolvePath resolvePath } @@ -75,17 +82,12 @@ type response struct { } func (f *request) createProtoWire(requestVariables *astjson.Value) ([]byte, error) { - wire, err := f.wire.createProtoWire(requestVariables) - if err != nil { - return nil, err - } - - return wire, nil + return f.wire.createProtoWire(requestVariables) } // TODO: Implement this -func (f *request) createProtoWireWithContext(requestVariables *astjson.Value, contextMessage protoref.Message) ([]byte, error) { - return nil, nil +func (f *request) createProtoWireWithContext(a arena.Arena, requestVariables *astjson.Value, contextMessage protoref.Message) ([]byte, error) { + return f.wire.createProtoWireWithContext(a, requestVariables, f.context, contextMessage) } func compileProgram(plan *RPCExecutionPlan, runtime *runtimeSchema) (*program, error) { @@ -167,7 +169,7 @@ func compileFetch(call *RPCCall, runtime *runtimeSchema, dependentCall *RPCCall) } switch f.kind { - case CallKindStandard, CallKindEntity: + case CallKindStandard, CallKindEntity, CallKindRequired: fetchRequest, err := compileFetchRequest(runtime, &call.Request, requestMessage) if err != nil { return fetchProgram{}, err @@ -209,6 +211,17 @@ func compileFetchRequest(runtime *runtimeSchema, rpcMessage *RPCMessage, message }, nil } +func getOneOfDescriptor(rtMessage *runtimeMessage, oneOfType OneOfType) protoref.OneofDescriptor { + switch oneOfType { + case OneOfTypeInterface: + return rtMessage.desc.Oneofs().ByName(protoref.Name("instance")) + case OneOfTypeUnion: + return rtMessage.desc.Oneofs().ByName(protoref.Name("value")) + } + + return nil +} + func compileMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, rtMessage *runtimeMessage, cycleMap map[string]*programMessage) (*programMessage, error) { if seen, ok := cycleMap[rpcMessage.Name]; ok { return seen, nil @@ -222,6 +235,59 @@ func compileMessage(runtime *runtimeSchema, rpcMessage *RPCMessage, rtMessage *r cycleMap[rpcMessage.Name] = msg + if rpcMessage.IsOneOf() { + msg.oneOfType = rpcMessage.OneOfType + msg.memberTypes = rpcMessage.MemberTypes + msg.oneOfFields = make(map[string][]programField) + + for _, memberType := range rpcMessage.MemberTypes { + fragmentFields, ok := rpcMessage.FragmentFields[memberType] + if !ok { + continue + } + + oneOfDescriptor := getOneOfDescriptor(rtMessage, rpcMessage.OneOfType) + if oneOfDescriptor == nil { + return nil, fmt.Errorf("oneof descriptor not found for message %s", rpcMessage.Name) + } + + fullName := "" + for i := range oneOfDescriptor.Fields().Len() { + field := oneOfDescriptor.Fields().Get(i) + if field.Kind() != protoref.MessageKind { + continue + } + + if field.Message().Name() == protoref.Name(memberType) { + fullName = string(field.Message().FullName()) + break + } + } + + memberTypeMessage := runtime.getMessageByFullName(fullName) + if memberTypeMessage == nil { + return nil, fmt.Errorf("message not found for name %s", fullName) + } + + oneOfFields := make([]programField, 0, len(fragmentFields)) + for _, fragmentField := range fragmentFields { + runtimeField := memberTypeMessage.fieldsByName[fragmentField.Name] + if runtimeField == nil { + return nil, fmt.Errorf("field not found for name %s", fragmentField.Name) + } + + requestField, err := compileField(runtime, fragmentField, runtimeField, cycleMap) + if err != nil { + return nil, err + } + oneOfFields = append(oneOfFields, requestField) + msg.oneOfFields[memberType] = oneOfFields + } + } + + return msg, nil + } + for _, f := range rpcMessage.Fields { rtFieldMessage := runtime.getMessageByName(rpcMessage.Name) if rtFieldMessage == nil { @@ -311,8 +377,8 @@ func compileFetchRequestWithContext(runtime *runtimeSchema, message *runtimeMess return request, nil } -func compileFetchRequestContext(message, dependentMessage *runtimeMessage, rpcMessage *RPCMessage) (*fetchRequestContext, error) { - if message == nil || dependentMessage == nil { +func compileFetchRequestContext(message, contextMessage *runtimeMessage, rpcMessage *RPCMessage) (*fetchRequestContext, error) { + if message == nil || contextMessage == nil { return nil, fmt.Errorf("unable to compile fetch request context: message or dependent message is nil") } @@ -322,9 +388,25 @@ func compileFetchRequestContext(message, dependentMessage *runtimeMessage, rpcMe fetchRequestContext := &fetchRequestContext{ message: message, + context: contextMessage, fields: make([]fetchRequestContextField, 0, len(rpcMessage.Fields)), } + for _, field := range rpcMessage.Fields { + rtField, found := message.fieldsByName[field.Name] + if !found { + return nil, fmt.Errorf("field not found for name %s", field.Name) + } + + fetchRequestContextField := &fetchRequestContextField{ + runtime: rtField, + p: field.ResolvePath, + jsonName: field.JSONPath, + // resolvePath: field.ResolvePath, + } + fetchRequestContext.fields = append(fetchRequestContext.fields, *fetchRequestContextField) + } + return fetchRequestContext, nil } diff --git a/v2/pkg/engine/datasource/grpc_datasource/runtime.go b/v2/pkg/engine/datasource/grpc_datasource/runtime.go index d2d88e44d..cb9ef5d67 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/runtime.go +++ b/v2/pkg/engine/datasource/grpc_datasource/runtime.go @@ -132,6 +132,15 @@ func (r *runtimeSchema) getMessageByName(name string) *runtimeMessage { return message } +func (r *runtimeSchema) getMessageByFullName(fullname string) *runtimeMessage { + message, found := r.messageByFullname[fullname] + if !found { + return nil + } + + return message +} + func (m *runtimeMessage) newEmptyMessage() protoref.Message { return m.dynamicType.New() } diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire.go b/v2/pkg/engine/datasource/grpc_datasource/wire.go index 30e1d6f0e..c17cf999c 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/wire.go +++ b/v2/pkg/engine/datasource/grpc_datasource/wire.go @@ -2,13 +2,22 @@ package grpcdatasource import ( "bytes" + "errors" "fmt" "math" + "reflect" + "strconv" + "strings" "github.com/wundergraph/astjson" + "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "google.golang.org/protobuf/encoding/protowire" + protoref "google.golang.org/protobuf/reflect/protoreflect" ) +var errShouldSkip = errors.New("skip") + type PreWiredInputMessage struct { size int buffer []byte @@ -30,23 +39,26 @@ func (c *PreWiredInputMessage) wire() ([]byte, error) { } type wireMessage struct { - fields []wireField - oneOfType OneOfType + fields []wireField + runtime *runtimeMessage + oneOfType OneOfType + oneOfFields map[string][]wireField } type wireField struct { - tag []byte - number protowire.Number - dataType DataType - wireType protowire.Type - runtimeMessage *runtimeMessage - runtimeEnum *runtimeEnum - staticValue string - jsonPath string - optional bool - repeated bool - listMetadata *ListMetadata - child *wireMessage + tag []byte + runtime *runtimeField + number protowire.Number + dataType DataType + wireType protowire.Type + runtimeEnum *runtimeEnum + staticValue string + jsonPath string + optional bool + repeated bool + listMetadata *ListMetadata + fieldMessage *runtimeMessage + child *wireMessage } const ( @@ -73,24 +85,57 @@ func compileWireMessage(schema *runtimeSchema, msg *programMessage, cycleMap map messageFields := msg.fields wm := &wireMessage{ - fields: make([]wireField, len(messageFields)), + runtime: msg.runtime, + fields: make([]wireField, len(messageFields)), + oneOfType: msg.oneOfType, } cycleMap[msg.name] = wm + if wm.oneOfType != OneOfTypeNone { + wm.oneOfFields = make(map[string][]wireField, len(msg.memberTypes)) + + for _, memberType := range msg.memberTypes { + + fields, err := compileMessageFields(schema, msg.oneOfFields[memberType], cycleMap) + if err != nil { + return nil, err + } + + wm.oneOfFields[memberType] = fields + } + } + + fields, err := compileMessageFields(schema, messageFields, cycleMap) + if err != nil { + return nil, err + } + + wm.fields = fields + return wm, nil +} + +func compileMessageFields(schema *runtimeSchema, messageFields []programField, cycleMap map[string]*wireMessage) ([]wireField, error) { + if len(messageFields) == 0 { + return nil, nil + } + + fields := make([]wireField, len(messageFields)) + for i := range messageFields { messageField := messageFields[i] wf := wireField{ - number: messageField.runtime.desc.Number(), - runtimeMessage: messageField.runtime.message, - dataType: messageField.dataType, - wireType: getWireType(messageField.runtime.dataType), - jsonPath: messageField.jsonPath, - staticValue: messageField.staticValue, - optional: messageField.optional, - repeated: messageField.repeated, - listMetadata: messageField.listMetadata, + runtime: messageField.runtime, + number: messageField.runtime.desc.Number(), + fieldMessage: messageField.runtime.message, + dataType: messageField.dataType, + wireType: getWireType(messageField.runtime.dataType), + jsonPath: messageField.jsonPath, + staticValue: messageField.staticValue, + optional: messageField.optional, + repeated: messageField.repeated, + listMetadata: messageField.listMetadata, } if messageField.enumName != "" { @@ -123,10 +168,59 @@ func compileWireMessage(schema *runtimeSchema, msg *programMessage, cycleMap map } wf.tag = protowire.AppendTag(nil, wf.number, wf.wireType) - wm.fields[i] = wf + fields[i] = wf } - return wm, nil + return fields, nil +} + +func (w *wireMessage) createProtoWireWithContext(a arena.Arena, data *astjson.Value, context *fetchRequestContext, contextMessage protoref.Message) ([]byte, error) { + buf := bytes.NewBuffer(make([]byte, 0, minBufferSize)) + + contextValues := make([]map[string]protoref.Value, 0) + for _, contextField := range context.fields { + values := resolveContextDataForPath(contextMessage, contextField.p) + + for index, value := range values { + if index >= len(contextValues) { + contextValues = append(contextValues, make(map[string]protoref.Value)) + } + + contextValues[index][contextField.jsonName] = value + } + } + + if len(contextValues) == 0 { + return nil, errShouldSkip + } + + contextVariables := astjson.ArrayValue(a) + arrayIndex := 0 + for _, contextValues := range contextValues { + contextVariable := astjson.ObjectValue(a) + for fieldName, contextValue := range contextValues { + contextVariable.Set(a, fieldName, convertProtoRefValue(a, contextValue)) + } + + contextVariables.SetArrayItem(a, arrayIndex, contextVariable) + arrayIndex++ + } + + for _, field := range w.fields { + if field.runtime.name == "context" { + if err := field.appendFieldWire(buf, contextVariables); err != nil { + return nil, err + } + + continue + } + + if err := field.appendFieldWire(buf, data); err != nil { + return nil, err + } + } + + return buf.Bytes(), nil } // createProtoWire creates a proto wire from the wire plan. @@ -140,6 +234,45 @@ func (w *wireMessage) createProtoWire(data *astjson.Value) ([]byte, error) { // appendProtoWire encodes the message fields into the given buffer. func (w *wireMessage) appendProtoWire(buf *bytes.Buffer, data *astjson.Value) error { + if w.oneOfType != OneOfTypeNone { + if !data.Exists("__typename") { + return fmt.Errorf("__typename is required for oneof fields") + } + + typeName := string(data.Get("__typename").GetStringBytes()) + + oneOfDescriptor := w.oneOfTypeDecriptor() + if oneOfDescriptor == nil { + return fmt.Errorf("oneof descriptor not found for message %s", w.runtime.name) + } + + fields := oneOfDescriptor.Fields() + for i := range fields.Len() { + field := fields.Get(i) + if field.Kind() != protoref.MessageKind { + continue + } + + if field.Message().Name() == protoref.Name(typeName) { + fieldNumber := field.Number() + buf.Write(protowire.AppendTag(buf.AvailableBuffer(), fieldNumber, protowire.BytesType)) + break + } + } + + oneOfFields := w.oneOfFields[typeName] + fieldsBuffer := bytes.NewBuffer(make([]byte, 0, minBufferSize)) + for _, field := range oneOfFields { + if err := field.appendFieldWire(fieldsBuffer, data); err != nil { + return err + } + } + + buf.Write(protowire.AppendBytes(buf.AvailableBuffer(), fieldsBuffer.Bytes())) + fieldsBuffer.Reset() + return nil + } + for _, field := range w.fields { if err := field.appendFieldWire(buf, data); err != nil { return err @@ -148,6 +281,22 @@ func (w *wireMessage) appendProtoWire(buf *bytes.Buffer, data *astjson.Value) er return nil } +func (w *wireMessage) oneOfTypeDecriptor() protoref.OneofDescriptor { + oneOfs := w.runtime.desc.Oneofs() + if oneOfs == nil || oneOfs.Len() == 0 { + return nil + } + + switch w.oneOfType { + case OneOfTypeInterface: + return oneOfs.ByName(protoref.Name("instance")) + case OneOfTypeUnion: + return oneOfs.ByName(protoref.Name("value")) + default: + return nil + } +} + func (f *wireField) appendFieldWire(buf *bytes.Buffer, data *astjson.Value) error { var fieldData *astjson.Value @@ -213,7 +362,7 @@ func (f *wireField) appendListFieldValue(buf *bytes.Buffer, data *astjson.Value, md := f.listMetadata.LevelInfo[level] level++ - runtimeMsg := f.runtimeMessage + runtimeMsg := f.fieldMessage if runtimeMsg == nil { return fmt.Errorf("runtime message not found for field %s", f.jsonPath) } @@ -247,12 +396,12 @@ func (f *wireField) appendListFieldValue(buf *bytes.Buffer, data *astjson.Value, for i := range elements { iwf := wireField{ - number: itemsField.desc.Number(), - dataType: f.dataType, - wireType: getWireType(itemsField.dataType), - runtimeMessage: itemsField.message, - listMetadata: f.listMetadata, - child: f.child, + number: itemsField.desc.Number(), + dataType: f.dataType, + wireType: getWireType(itemsField.dataType), + fieldMessage: itemsField.message, + listMetadata: f.listMetadata, + child: f.child, } iwf.tag = protowire.AppendTag(nil, iwf.number, iwf.wireType) @@ -272,21 +421,21 @@ func (f *wireField) appendListFieldValue(buf *bytes.Buffer, data *astjson.Value, } func (f *wireField) appendOptionalScalarFieldValue(buf *bytes.Buffer, data *astjson.Value) error { - if f.runtimeMessage == nil { + if f.fieldMessage == nil { return fmt.Errorf("runtime message not found for optional scalar field %s but was expected", f.jsonPath) } - wrapperField, ok := f.runtimeMessage.fieldsByName[knownTypeOptionalFieldValueName] + wrapperField, ok := f.fieldMessage.fieldsByName[knownTypeOptionalFieldValueName] if !ok { - return fmt.Errorf("wrapper field not found for message %s but was expected", f.runtimeMessage.name) + return fmt.Errorf("wrapper field not found for message %s but was expected", f.fieldMessage.name) } wf := wireField{ - number: wrapperField.desc.Number(), - dataType: wrapperField.dataType, - wireType: getWireType(wrapperField.dataType), - jsonPath: f.jsonPath, - runtimeMessage: wrapperField.message, + number: wrapperField.desc.Number(), + dataType: wrapperField.dataType, + wireType: getWireType(wrapperField.dataType), + jsonPath: f.jsonPath, + fieldMessage: wrapperField.message, } wf.tag = protowire.AppendTag(nil, wf.number, wf.wireType) @@ -313,8 +462,8 @@ func (f *wireField) appendFieldValue(buf *bytes.Buffer, data *astjson.Value) err return err } buf.Write(f.tag) - buf.Write(protowire.AppendVarint(buf.AvailableBuffer(), uint64(childBuf.Len()))) - buf.Write(childBuf.Bytes()) + buf.Write(protowire.AppendBytes(buf.AvailableBuffer(), childBuf.Bytes())) + childBuf.Reset() return nil } @@ -357,21 +506,29 @@ func getUint64Value(data *astjson.Value) uint64 { } func (f *wireField) getEnumValue(data *astjson.Value) (uint64, error) { - enumValueName := data.GetStringBytes() - if len(enumValueName) == 0 { - return 0, fmt.Errorf("enum value name is required for enum field %s", f.jsonPath) - } + switch data.Type() { + case astjson.TypeNumber: + return data.GetUint64(), nil + case astjson.TypeString: + enumValueName := data.GetStringBytes() + if len(enumValueName) == 0 { + return 0, fmt.Errorf("enum value name is required for enum field %s", f.jsonPath) + } - ev, found := f.runtimeEnum.valuesByName[string(enumValueName)] - if !found { - return 0, fmt.Errorf("enum value not found for name %s", string(enumValueName)) - } + ev, found := f.runtimeEnum.valuesByName[string(enumValueName)] + if !found { + return 0, fmt.Errorf("enum value not found for name %s", string(enumValueName)) + } - if ev.value < 0 { - return 0, fmt.Errorf("enum value %s is negative for enum field %s", string(enumValueName), f.jsonPath) - } + if ev.value < 0 { + return 0, fmt.Errorf("enum value %s is negative for enum field %s", string(enumValueName), f.jsonPath) + } + + return uint64(ev.value), nil - return uint64(ev.value), nil + default: + return 0, fmt.Errorf("unsupported enum type %s", data.Type()) + } } func getWireType(dataType DataType) protowire.Type { @@ -388,3 +545,236 @@ func getWireType(dataType DataType) protowire.Type { return protowire.VarintType } } + +// resolveContextDataForPath resolves the data for a given path in the context message. +func resolveContextDataForPath(message protoref.Message, path ast.Path) []protoref.Value { + if path.Len() == 0 { + return nil + } + + segment := path[0] + path = path[1:] + + msg, fd := getMessageField(message, segment.FieldName.String()) + if !msg.IsValid() { + return nil + } + + if fd.IsList() { + return resolveListDataForPath(msg.List(), fd, path) + } + + return resolveDataForPath(msg.Message(), path) +} + +// resolveListDataForPath resolves the data for a given path in a list message. +func resolveListDataForPath(message protoref.List, fd protoref.FieldDescriptor, path ast.Path) []protoref.Value { + if !message.IsValid() { + return nil + } + + if path.Len() == 0 { + return nil + } + + result := make([]protoref.Value, 0, message.Len()) + + for i := range message.Len() { + item := message.Get(i) + + switch fd.Kind() { + case protoref.MessageKind: + values := resolveDataForPath(item.Message(), path) + + for _, val := range values { + if list, isList := val.Interface().(protoref.List); isList { + values := resolveListDataForPath(list, fd, path[1:]) + result = append(result, values...) + continue + } else { + result = append(result, val) + } + } + + default: + result = append(result, item) + } + } + + return result +} + +// resolveDataForPath resolves the data for a given path in a message. +func resolveDataForPath(message protoref.Message, path ast.Path) []protoref.Value { + if !message.IsValid() { + return nil + } + + if path.Len() == 0 { + return nil + } + + segment := path[0] + + if fn := segment.FieldName.String(); strings.HasPrefix(fn, "@") { + list := resolveUnderlyingList(message, fn) + + result := make([]protoref.Value, 0, len(list)) + for _, item := range list { + result = append(result, resolveDataForPath(item.Message(), path[1:])...) + } + + return result + } + + field, fd := getMessageField(message, segment.FieldName.String()) + if !field.IsValid() { + return nil + } + + switch fd.Kind() { + case protoref.MessageKind: + if fd.IsList() { + if !field.List().IsValid() { + return nil + } + + return []protoref.Value{protoref.ValueOfList(field.List())} + } + + if !field.Message().IsValid() { + return nil + } + + return resolveDataForPath(field.Message(), path[1:]) + default: + return []protoref.Value{field} + } +} + +// getMessageField gets the field from the message by its name. +func getMessageField(message protoref.Message, fieldName string) (protoref.Value, protoref.FieldDescriptor) { + fd := message.Descriptor().Fields().ByName(protoref.Name(fieldName)) + if fd == nil { + return protoref.Value{}, nil + } + + return message.Get(fd), fd +} + +// resolveUnderlyingList resolves the underlying list message from a nested list message. +// +// message ListOfFloat { +// message List { +// repeated double items = 1; +// } +// List list = 1; +// } +func resolveUnderlyingList(msg protoref.Message, fieldName string) []protoref.Value { + nestingLevel := 0 + for _, char := range fieldName { + if char != '@' { + break + } + nestingLevel++ + } + + listFieldValue := msg.Get(msg.Descriptor().Fields().ByName(protoref.Name(fieldName[nestingLevel:]))) + if !listFieldValue.IsValid() { + return nil + } + + return resolveUnderlyingListItems(listFieldValue, nestingLevel) + +} + +// resolveUnderlyingListItems resolves the items in a list message. +// +// message ListOfFloat { +// message List { +// repeated double items = 1; +// } +// List list = 1; +// } +func resolveUnderlyingListItems(value protoref.Value, nestingLevel int) []protoref.Value { + // The field number of the list and items field in the message + const listAndItemsFieldNumber = 1 + msg := value.Message() + fd := msg.Descriptor().Fields().ByNumber(listAndItemsFieldNumber) + if fd == nil { + return nil + } + + listMsg := msg.Get(fd) + if !listMsg.IsValid() { + return nil + } + + itemsValue := listMsg.Message().Get(listMsg.Message().Descriptor().Fields().ByNumber(listAndItemsFieldNumber)) + if !itemsValue.IsValid() { + return nil + } + + itemsList := itemsValue.List() + itemsListLen := itemsList.Len() + if itemsListLen == 0 { + return nil + } + + if nestingLevel > 1 { + items := make([]protoref.Value, 0, itemsListLen) + for i := 0; i < itemsListLen; i++ { + items = append(items, resolveUnderlyingListItems(itemsList.Get(i), nestingLevel-1)...) + } + + return items + } + + result := make([]protoref.Value, itemsListLen) + for i := 0; i < itemsListLen; i++ { + result[i] = itemsList.Get(i) + } + + return result +} + +func convertProtoRefValue(a arena.Arena, value protoref.Value) *astjson.Value { + switch t := value.Interface().(type) { + case nil: + return astjson.NullValue + case bool: + if t { + return astjson.TrueValue(a) + } + return astjson.FalseValue(a) + case int32: + return astjson.IntValue(a, int(t)) + case int64: + return astjson.NumberValue(a, strconv.FormatInt(t, 10)) + case uint32: + return astjson.NumberValue(a, strconv.FormatUint(uint64(t), 10)) + case uint64: + return astjson.NumberValue(a, strconv.FormatUint(t, 10)) + case float32: + return astjson.FloatValue(a, float64(t)) + case float64: + return astjson.FloatValue(a, t) + case string: + return astjson.StringValue(a, t) + case []byte: + return astjson.StringValueBytes(a, t) + case protoref.EnumNumber: + return astjson.IntValue(a, int(t)) + // case protoref.Message: + // ov := astjson.ObjectValue(a) + // // TODO: Extract the message fields and set them on the object value + // return ov + // case protoref.List: + // av := astjson.ArrayValue(a) + // // TODO: Extract the list items and set them on the array value + // return av + default: + fmt.Println("unsupported type", reflect.TypeOf(t).Name()) + return astjson.NullValue + } +} From 63abefabd5739a436e1bb394bec9a284ec939b7a Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 24 Apr 2026 16:19:50 +0200 Subject: [PATCH 09/12] chore: fix skipped call --- v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go | 1 + 1 file changed, 1 insertion(+) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index b9cf3e761..cd7b146d4 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -268,6 +268,7 @@ func createProtoWire(a arena.Arena, fetch *fetchProgram, callMap map[int]fetchDa } callMap[fetch.id] = fetchResult + return nil, true, nil } buffer, err = fetch.request.createProtoWireWithContext(a, requestVariables, contextFetch.responseMessage) From 0dd5b1cf5fbcf19356d5d34d85e3b02aa6f2ebb7 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 24 Apr 2026 16:23:32 +0200 Subject: [PATCH 10/12] chore: remove unused function --- v2/pkg/ast/path.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/v2/pkg/ast/path.go b/v2/pkg/ast/path.go index 392705fa2..19493ec60 100644 --- a/v2/pkg/ast/path.go +++ b/v2/pkg/ast/path.go @@ -128,14 +128,6 @@ func (p Path) String() string { return out } -func (p Path) ToPathItemStrings() []string { - out := make([]string, len(p)) - for i := range p { - out[i] = unsafebytes.BytesToString(p[i].FieldName) - } - return out -} - func (p Path) DotDelimitedString() string { builder := strings.Builder{} From 3298793de875e0662d24fecd934318805d30f210 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 24 Apr 2026 16:33:27 +0200 Subject: [PATCH 11/12] chore: make linter happy --- v2/pkg/engine/datasource/grpc_datasource/codec.go | 2 +- v2/pkg/engine/datasource/grpc_datasource/program.go | 4 +++- v2/pkg/engine/datasource/grpc_datasource/program_test.go | 1 + v2/pkg/engine/datasource/grpc_datasource/wire.go | 6 ++++-- v2/pkg/engine/datasource/grpc_datasource/wire_test.go | 3 ++- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/codec.go b/v2/pkg/engine/datasource/grpc_datasource/codec.go index 712a8f086..67fbdc6cc 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/codec.go +++ b/v2/pkg/engine/datasource/grpc_datasource/codec.go @@ -5,7 +5,6 @@ import ( "google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding/proto" - _ "google.golang.org/grpc/encoding/proto" "google.golang.org/grpc/mem" ) @@ -13,6 +12,7 @@ var defaultCodec = encoding.GetCodecV2("proto") type connectCodec struct{} +// TODO: force codec as client option instead of registering it globally. func init() { encoding.RegisterCodecV2(&connectCodec{}) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/program.go b/v2/pkg/engine/datasource/grpc_datasource/program.go index d2080c566..93885d0ec 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/program.go +++ b/v2/pkg/engine/datasource/grpc_datasource/program.go @@ -3,10 +3,12 @@ package grpcdatasource import ( "fmt" + protoref "google.golang.org/protobuf/reflect/protoreflect" + "github.com/wundergraph/astjson" "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" - protoref "google.golang.org/protobuf/reflect/protoreflect" ) type program struct { diff --git a/v2/pkg/engine/datasource/grpc_datasource/program_test.go b/v2/pkg/engine/datasource/grpc_datasource/program_test.go index d47570776..7f40d1df3 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/program_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/program_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" ) diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire.go b/v2/pkg/engine/datasource/grpc_datasource/wire.go index c17cf999c..68426232b 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/wire.go +++ b/v2/pkg/engine/datasource/grpc_datasource/wire.go @@ -9,11 +9,13 @@ import ( "strconv" "strings" + "google.golang.org/protobuf/encoding/protowire" + protoref "google.golang.org/protobuf/reflect/protoreflect" + "github.com/wundergraph/astjson" "github.com/wundergraph/go-arena" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" - "google.golang.org/protobuf/encoding/protowire" - protoref "google.golang.org/protobuf/reflect/protoreflect" ) var errShouldSkip = errors.New("skip") diff --git a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go index e9623f84f..6ac7148e3 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/wire_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/wire_test.go @@ -5,10 +5,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/wundergraph/astjson" "google.golang.org/protobuf/proto" protoref "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/dynamicpb" + + "github.com/wundergraph/astjson" ) var testWireSchema = ` From 1c88cb85b050686574e9a82fe4265ffaf89617fa Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 24 Apr 2026 17:06:39 +0200 Subject: [PATCH 12/12] chore: improve code --- .../datasource/grpc_datasource/program.go | 34 +++++++------------ 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/program.go b/v2/pkg/engine/datasource/grpc_datasource/program.go index 93885d0ec..290620c52 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/program.go +++ b/v2/pkg/engine/datasource/grpc_datasource/program.go @@ -351,31 +351,23 @@ func compileFetchRequestWithContext(runtime *runtimeSchema, message *runtimeMess request.message = requestMessage request.fields = requestMessage.fields - // context and field_args - for _, field := range rpcMessage.Fields { - switch field.Name { - case "context": - contextField, found := message.fieldsByName[field.Name] - if !found { - return nil, fmt.Errorf("context message not found for method %s", rpcMessage.Name) - } - - fetchRequestContext, err := compileFetchRequestContext(contextField.message, dependentMessage, field.Message) - if err != nil { - return nil, err - } + contextField := rpcMessage.Fields.ByName(contextFieldName) + if contextField == nil { + return nil, fmt.Errorf("context field not found for method %s", rpcMessage.Name) + } - request.context = fetchRequestContext - case "field_args": - // wireMessage, err := compileWireMessage(field.Message, message) - // if err != nil { - // return nil, err - // } + contextRuntimeField, found := message.fieldsByName[contextFieldName] + if !found { + return nil, fmt.Errorf("context field not found for method %s", rpcMessage.Name) + } - // request.wire = wireMessage - } + fetchRequestContext, err := compileFetchRequestContext(contextRuntimeField.message, dependentMessage, contextField.Message) + if err != nil { + return nil, err } + request.context = fetchRequestContext + return request, nil }