Skip to content

Commit 94031e5

Browse files
authored
fix: handle scalar values for lists (#1155)
1 parent 193fa3c commit 94031e5

13 files changed

Lines changed: 508 additions & 127 deletions

File tree

execution/engine/execution_engine_grpc_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,13 @@ func TestGRPCSubgraphExecution(t *testing.T) {
359359
t.Run("should run query with a recursive type", func(t *testing.T) {
360360
operation := graphql.Request{
361361
OperationName: "RecursiveTypeQuery",
362-
Query: `query RecursiveTypeQuery { recursiveType { id name recursiveType { id recursiveType { id name recursiveType { id name } } name } } }`,
362+
Query: `query RecursiveTypeQuery { recursiveType { id name recursiveType { id recursiveType { id name } name } } }`,
363363
}
364364

365365
response, err := executeOperation(t, conn, operation, withGRPCMapping(mapping.DefaultGRPCMapping()))
366366

367367
require.NoError(t, err)
368-
require.Equal(t, `{"data":{"recursiveType":{"id":"recursive-1","name":"Level 1","recursiveType":{"id":"recursive-2","recursiveType":{"id":"recursive-3","name":"Level 3","recursiveType":{"id":"","name":""}},"name":"Level 2"}}}}`, response)
368+
require.Equal(t, `{"data":{"recursiveType":{"id":"recursive-1","name":"Level 1","recursiveType":{"id":"recursive-2","recursiveType":{"id":"recursive-3","name":"Level 3"},"name":"Level 2"}}}}`, response)
369369
})
370370

371371
t.Run("should stop when no mapping is found for the operation request", func(t *testing.T) {

v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ func (p *Planner[T]) DownstreamResponseFieldAlias(downstreamFieldRef int) (alias
266266

267267
func (p *Planner[T]) DataSourcePlanningBehavior() plan.DataSourcePlanningBehavior {
268268
return plan.DataSourcePlanningBehavior{
269-
MergeAliasedRootNodes: !p.config.IsGRPC(),
269+
MergeAliasedRootNodes: true,
270270
OverrideFieldPathFromAlias: true,
271271
IncludeTypeNameFields: true,
272272
}
@@ -367,6 +367,7 @@ func (p *Planner[T]) ConfigureFetch() resolve.FetchConfiguration {
367367
Definition: p.config.schemaConfiguration.upstreamSchemaAst,
368368
Mapping: p.config.grpc.Mapping,
369369
Compiler: p.config.grpc.Compiler,
370+
Disabled: p.config.grpc.Disabled,
370371
// TODO: remove fallback logic in visitor for subgraph name and
371372
// add proper error handling if the subgraph name is not set in the mapping
372373
SubgraphName: p.dataSourceConfig.Name(),
@@ -1704,6 +1705,7 @@ type Factory[T Configuration] struct {
17041705
executionContext context.Context
17051706
httpClient *http.Client
17061707
grpcClient grpc.ClientConnInterface
1708+
grpcClientProvider func() grpc.ClientConnInterface
17071709
subscriptionClient GraphQLSubscriptionClient
17081710
}
17091711

@@ -1746,6 +1748,26 @@ func NewFactoryGRPC(executionContext context.Context, grpcClient grpc.ClientConn
17461748
}, nil
17471749
}
17481750

1751+
// NewFactoryGRPCClientProvider creates a new factory for the GraphQL datasource planner
1752+
// This factory is used when the gRPC client is provided by a function.
1753+
// This is useful when you don't want to provide a static client to the factory and let the consumer
1754+
// decide how to provide the client to the datasource.
1755+
// For example when you need to recreate the client in case of a connection error.
1756+
func NewFactoryGRPCClientProvider(executionContext context.Context, clientProvider func() grpc.ClientConnInterface) (*Factory[Configuration], error) {
1757+
if executionContext == nil {
1758+
return nil, fmt.Errorf("execution context is required")
1759+
}
1760+
1761+
if clientProvider == nil {
1762+
return nil, fmt.Errorf("provider function is required")
1763+
}
1764+
1765+
return &Factory[Configuration]{
1766+
executionContext: executionContext,
1767+
grpcClientProvider: clientProvider,
1768+
}, nil
1769+
}
1770+
17491771
func (p *Planner[T]) getKit() *printKit {
17501772
return printKitPool.Get().(*printKit)
17511773
}
@@ -1757,9 +1779,14 @@ func (p *Planner[T]) releaseKit(kit *printKit) {
17571779
}
17581780

17591781
func (f *Factory[T]) Planner(logger abstractlogger.Logger) plan.DataSourcePlanner[T] {
1782+
grpcClient := f.grpcClient
1783+
if f.grpcClientProvider != nil {
1784+
grpcClient = f.grpcClientProvider()
1785+
}
1786+
17601787
return &Planner[T]{
17611788
fetchClient: f.httpClient,
1762-
grpcClient: f.grpcClient,
1789+
grpcClient: grpcClient,
17631790
subscriptionClient: f.subscriptionClient,
17641791
}
17651792
}

v2/pkg/engine/datasource/grpc_datasource/compiler.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,13 @@ func (p *RPCCompiler) buildProtoMessage(inputMessage Message, rpcMessage *RPCMes
408408

409409
// Process each element and append to the list
410410
for _, element := range elements {
411-
fieldMsg := p.buildProtoMessage(p.doc.Messages[field.MessageRef], rpcField.Message, element)
412-
list.Append(protoref.ValueOfMessage(fieldMsg))
411+
switch field.Type {
412+
case DataTypeMessage:
413+
fieldMsg := p.buildProtoMessage(p.doc.Messages[field.MessageRef], rpcField.Message, element)
414+
list.Append(protoref.ValueOfMessage(fieldMsg))
415+
default:
416+
list.Append(p.setValueForKind(field.Type, element))
417+
}
413418
}
414419

415420
continue

v2/pkg/engine/datasource/grpc_datasource/configuration.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ type EnumValueMapping struct {
3030
}
3131

3232
type GRPCConfiguration struct {
33+
Disabled bool
3334
Mapping *GRPCMapping
3435
Compiler *RPCCompiler
3536
}

v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go

Lines changed: 100 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@ import (
1010
"bytes"
1111
"context"
1212
"fmt"
13+
"strconv"
1314

1415
"github.com/tidwall/gjson"
1516
"github.com/wundergraph/astjson"
1617
"github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
1718
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
1819
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
1920
"google.golang.org/grpc"
21+
"google.golang.org/grpc/codes"
22+
"google.golang.org/grpc/status"
2023
protoref "google.golang.org/protobuf/reflect/protoreflect"
2124
)
2225

@@ -28,10 +31,11 @@ var _ resolve.DataSource = (*DataSource)(nil)
2831
// transforms the responses back to GraphQL format.
2932
type DataSource struct {
3033
// Invocations is a list of gRPC invocations to be executed
31-
plan *RPCExecutionPlan
32-
cc grpc.ClientConnInterface
33-
rc *RPCCompiler
34-
mapping *GRPCMapping
34+
plan *RPCExecutionPlan
35+
cc grpc.ClientConnInterface
36+
rc *RPCCompiler
37+
mapping *GRPCMapping
38+
disabled bool
3539
}
3640

3741
type ProtoConfig struct {
@@ -44,6 +48,7 @@ type DataSourceConfig struct {
4448
Compiler *RPCCompiler
4549
SubgraphName string
4650
Mapping *GRPCMapping
51+
Disabled bool
4752
}
4853

4954
// NewDataSource creates a new gRPC datasource
@@ -55,10 +60,11 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D
5560
}
5661

5762
return &DataSource{
58-
plan: plan,
59-
cc: client,
60-
rc: config.Compiler,
61-
mapping: config.Mapping,
63+
plan: plan,
64+
cc: client,
65+
rc: config.Compiler,
66+
mapping: config.Mapping,
67+
disabled: config.Disabled,
6268
}, nil
6369
}
6470

@@ -69,6 +75,11 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D
6975
// The input is expected to contain the necessary information to make
7076
// a gRPC call, including service name, method name, and request data.
7177
func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) {
78+
if d.disabled {
79+
out.Write(writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used")))
80+
return nil
81+
}
82+
7283
// get variables from input
7384
variables := gjson.Parse(string(input)).Get("body.variables")
7485

@@ -123,7 +134,7 @@ func (d *DataSource) LoadWithFiles(ctx context.Context, input []byte, files []*h
123134

124135
func (d *DataSource) marshalResponseJSON(arena *astjson.Arena, message *RPCMessage, data protoref.Message) (*astjson.Value, error) {
125136
if message == nil {
126-
return nil, nil
137+
return arena.NewNull(), nil
127138
}
128139

129140
root := arena.NewObject()
@@ -158,24 +169,41 @@ func (d *DataSource) marshalResponseJSON(arena *astjson.Arena, message *RPCMessa
158169
}
159170

160171
if fd.IsList() {
172+
list := data.Get(fd).List()
173+
if !list.IsValid() {
174+
root.Set(field.JSONPath, arena.NewNull())
175+
continue
176+
}
177+
161178
arr := arena.NewArray()
162179
root.Set(field.JSONPath, arr)
163-
list := data.Get(fd).List()
164180
for i := 0; i < list.Len(); i++ {
165-
message := list.Get(i).Message()
166-
value, err := d.marshalResponseJSON(arena, field.Message, message)
167-
if err != nil {
168-
return nil, err
181+
182+
switch fd.Kind() {
183+
case protoref.MessageKind:
184+
message := list.Get(i).Message()
185+
value, err := d.marshalResponseJSON(arena, field.Message, message)
186+
if err != nil {
187+
return nil, err
188+
}
189+
190+
arr.SetArrayItem(i, value)
191+
default:
192+
d.setArrayItem(i, arena, arr, list.Get(i), fd)
169193
}
170194

171-
arr.SetArrayItem(i, value)
172195
}
173196

174197
continue
175198
}
176199

177200
if fd.Kind() == protoref.MessageKind {
178201
msg := data.Get(fd).Message()
202+
if !msg.IsValid() {
203+
root.Set(field.JSONPath, arena.NewNull())
204+
continue
205+
}
206+
179207
value, err := d.marshalResponseJSON(arena, field.Message, msg)
180208
if err != nil {
181209
return nil, err
@@ -200,6 +228,11 @@ func (d *DataSource) marshalResponseJSON(arena *astjson.Arena, message *RPCMessa
200228
}
201229

202230
func (d *DataSource) setJSONValue(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) {
231+
if !data.IsValid() {
232+
root.Set(name, arena.NewNull())
233+
return
234+
}
235+
203236
switch fd.Kind() {
204237
case protoref.BoolKind:
205238
boolValue := data.Get(fd).Bool()
@@ -213,7 +246,7 @@ func (d *DataSource) setJSONValue(arena *astjson.Arena, root *astjson.Value, nam
213246
case protoref.Int32Kind, protoref.Int64Kind:
214247
root.Set(name, arena.NewNumberInt(int(data.Get(fd).Int())))
215248
case protoref.Uint32Kind, protoref.Uint64Kind:
216-
root.Set(name, arena.NewNumberString(fmt.Sprintf("%d", data.Get(fd).Uint())))
249+
root.Set(name, arena.NewNumberString(strconv.FormatUint(data.Get(fd).Uint(), 10)))
217250
case protoref.FloatKind, protoref.DoubleKind:
218251
root.Set(name, arena.NewNumberFloat64(data.Get(fd).Float()))
219252
case protoref.BytesKind:
@@ -236,6 +269,48 @@ func (d *DataSource) setJSONValue(arena *astjson.Arena, root *astjson.Value, nam
236269
}
237270
}
238271

272+
func (d *DataSource) setArrayItem(index int, arena *astjson.Arena, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) {
273+
if !data.IsValid() {
274+
array.SetArrayItem(index, arena.NewNull())
275+
return
276+
}
277+
278+
switch fd.Kind() {
279+
case protoref.BoolKind:
280+
boolValue := data.Bool()
281+
if boolValue {
282+
array.SetArrayItem(index, arena.NewTrue())
283+
} else {
284+
array.SetArrayItem(index, arena.NewFalse())
285+
}
286+
case protoref.StringKind:
287+
array.SetArrayItem(index, arena.NewString(data.String()))
288+
case protoref.Int32Kind, protoref.Int64Kind:
289+
array.SetArrayItem(index, arena.NewNumberInt(int(data.Int())))
290+
case protoref.Uint32Kind, protoref.Uint64Kind:
291+
array.SetArrayItem(index, arena.NewNumberString(strconv.FormatUint(data.Uint(), 10)))
292+
case protoref.FloatKind, protoref.DoubleKind:
293+
array.SetArrayItem(index, arena.NewNumberFloat64(data.Float()))
294+
case protoref.BytesKind:
295+
array.SetArrayItem(index, arena.NewStringBytes(data.Bytes()))
296+
case protoref.EnumKind:
297+
enumDesc := fd.Enum()
298+
enumValueDesc := enumDesc.Values().ByNumber(data.Enum())
299+
if enumValueDesc == nil {
300+
array.SetArrayItem(index, arena.NewNull())
301+
return
302+
}
303+
304+
graphqlValue, ok := d.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name()))
305+
if !ok {
306+
array.SetArrayItem(index, arena.NewNull())
307+
return
308+
}
309+
310+
array.SetArrayItem(index, arena.NewString(graphqlValue))
311+
}
312+
}
313+
239314
func writeErrorBytes(err error) []byte {
240315
a := astjson.Arena{}
241316
errorRoot := a.NewObject()
@@ -244,6 +319,15 @@ func writeErrorBytes(err error) []byte {
244319

245320
errorItem := a.NewObject()
246321
errorItem.Set("message", a.NewString(err.Error()))
322+
323+
extensions := a.NewObject()
324+
if st, ok := status.FromError(err); ok {
325+
extensions.Set("code", a.NewString(st.Code().String()))
326+
} else {
327+
extensions.Set("code", a.NewString(codes.Internal.String()))
328+
}
329+
330+
errorItem.Set("extensions", extensions)
247331
errorArray.SetArrayItem(0, errorItem)
248332

249333
return errorRoot.MarshalTo(nil)

0 commit comments

Comments
 (0)