@@ -10,13 +10,16 @@ import (
1010 "bytes"
1111 "context"
1212 "fmt"
13+ "strconv"
1314
1415 "github.com/tidwall/gjson"
1516 "github.com/wundergraph/astjson"
1617 "github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
1718 "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
1819 "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
1920 "google.golang.org/grpc"
21+ "google.golang.org/grpc/codes"
22+ "google.golang.org/grpc/status"
2023 protoref "google.golang.org/protobuf/reflect/protoreflect"
2124)
2225
@@ -28,10 +31,11 @@ var _ resolve.DataSource = (*DataSource)(nil)
2831// transforms the responses back to GraphQL format.
2932type DataSource struct {
3033 // Invocations is a list of gRPC invocations to be executed
31- plan * RPCExecutionPlan
32- cc grpc.ClientConnInterface
33- rc * RPCCompiler
34- mapping * GRPCMapping
34+ plan * RPCExecutionPlan
35+ cc grpc.ClientConnInterface
36+ rc * RPCCompiler
37+ mapping * GRPCMapping
38+ disabled bool
3539}
3640
3741type ProtoConfig struct {
@@ -44,6 +48,7 @@ type DataSourceConfig struct {
4448 Compiler * RPCCompiler
4549 SubgraphName string
4650 Mapping * GRPCMapping
51+ Disabled bool
4752}
4853
4954// NewDataSource creates a new gRPC datasource
@@ -55,10 +60,11 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D
5560 }
5661
5762 return & DataSource {
58- plan : plan ,
59- cc : client ,
60- rc : config .Compiler ,
61- mapping : config .Mapping ,
63+ plan : plan ,
64+ cc : client ,
65+ rc : config .Compiler ,
66+ mapping : config .Mapping ,
67+ disabled : config .Disabled ,
6268 }, nil
6369}
6470
@@ -69,6 +75,11 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D
6975// The input is expected to contain the necessary information to make
7076// a gRPC call, including service name, method name, and request data.
7177func (d * DataSource ) Load (ctx context.Context , input []byte , out * bytes.Buffer ) (err error ) {
78+ if d .disabled {
79+ out .Write (writeErrorBytes (fmt .Errorf ("gRPC datasource needs to be enabled to be used" )))
80+ return nil
81+ }
82+
7283 // get variables from input
7384 variables := gjson .Parse (string (input )).Get ("body.variables" )
7485
@@ -123,7 +134,7 @@ func (d *DataSource) LoadWithFiles(ctx context.Context, input []byte, files []*h
123134
124135func (d * DataSource ) marshalResponseJSON (arena * astjson.Arena , message * RPCMessage , data protoref.Message ) (* astjson.Value , error ) {
125136 if message == nil {
126- return nil , nil
137+ return arena . NewNull () , nil
127138 }
128139
129140 root := arena .NewObject ()
@@ -158,24 +169,41 @@ func (d *DataSource) marshalResponseJSON(arena *astjson.Arena, message *RPCMessa
158169 }
159170
160171 if fd .IsList () {
172+ list := data .Get (fd ).List ()
173+ if ! list .IsValid () {
174+ root .Set (field .JSONPath , arena .NewNull ())
175+ continue
176+ }
177+
161178 arr := arena .NewArray ()
162179 root .Set (field .JSONPath , arr )
163- list := data .Get (fd ).List ()
164180 for i := 0 ; i < list .Len (); i ++ {
165- message := list .Get (i ).Message ()
166- value , err := d .marshalResponseJSON (arena , field .Message , message )
167- if err != nil {
168- return nil , err
181+
182+ switch fd .Kind () {
183+ case protoref .MessageKind :
184+ message := list .Get (i ).Message ()
185+ value , err := d .marshalResponseJSON (arena , field .Message , message )
186+ if err != nil {
187+ return nil , err
188+ }
189+
190+ arr .SetArrayItem (i , value )
191+ default :
192+ d .setArrayItem (i , arena , arr , list .Get (i ), fd )
169193 }
170194
171- arr .SetArrayItem (i , value )
172195 }
173196
174197 continue
175198 }
176199
177200 if fd .Kind () == protoref .MessageKind {
178201 msg := data .Get (fd ).Message ()
202+ if ! msg .IsValid () {
203+ root .Set (field .JSONPath , arena .NewNull ())
204+ continue
205+ }
206+
179207 value , err := d .marshalResponseJSON (arena , field .Message , msg )
180208 if err != nil {
181209 return nil , err
@@ -200,6 +228,11 @@ func (d *DataSource) marshalResponseJSON(arena *astjson.Arena, message *RPCMessa
200228}
201229
202230func (d * DataSource ) setJSONValue (arena * astjson.Arena , root * astjson.Value , name string , data protoref.Message , fd protoref.FieldDescriptor ) {
231+ if ! data .IsValid () {
232+ root .Set (name , arena .NewNull ())
233+ return
234+ }
235+
203236 switch fd .Kind () {
204237 case protoref .BoolKind :
205238 boolValue := data .Get (fd ).Bool ()
@@ -213,7 +246,7 @@ func (d *DataSource) setJSONValue(arena *astjson.Arena, root *astjson.Value, nam
213246 case protoref .Int32Kind , protoref .Int64Kind :
214247 root .Set (name , arena .NewNumberInt (int (data .Get (fd ).Int ())))
215248 case protoref .Uint32Kind , protoref .Uint64Kind :
216- root .Set (name , arena .NewNumberString (fmt . Sprintf ( "%d" , data .Get (fd ).Uint ())))
249+ root .Set (name , arena .NewNumberString (strconv . FormatUint ( data .Get (fd ).Uint (), 10 )))
217250 case protoref .FloatKind , protoref .DoubleKind :
218251 root .Set (name , arena .NewNumberFloat64 (data .Get (fd ).Float ()))
219252 case protoref .BytesKind :
@@ -236,6 +269,48 @@ func (d *DataSource) setJSONValue(arena *astjson.Arena, root *astjson.Value, nam
236269 }
237270}
238271
272+ func (d * DataSource ) setArrayItem (index int , arena * astjson.Arena , array * astjson.Value , data protoref.Value , fd protoref.FieldDescriptor ) {
273+ if ! data .IsValid () {
274+ array .SetArrayItem (index , arena .NewNull ())
275+ return
276+ }
277+
278+ switch fd .Kind () {
279+ case protoref .BoolKind :
280+ boolValue := data .Bool ()
281+ if boolValue {
282+ array .SetArrayItem (index , arena .NewTrue ())
283+ } else {
284+ array .SetArrayItem (index , arena .NewFalse ())
285+ }
286+ case protoref .StringKind :
287+ array .SetArrayItem (index , arena .NewString (data .String ()))
288+ case protoref .Int32Kind , protoref .Int64Kind :
289+ array .SetArrayItem (index , arena .NewNumberInt (int (data .Int ())))
290+ case protoref .Uint32Kind , protoref .Uint64Kind :
291+ array .SetArrayItem (index , arena .NewNumberString (strconv .FormatUint (data .Uint (), 10 )))
292+ case protoref .FloatKind , protoref .DoubleKind :
293+ array .SetArrayItem (index , arena .NewNumberFloat64 (data .Float ()))
294+ case protoref .BytesKind :
295+ array .SetArrayItem (index , arena .NewStringBytes (data .Bytes ()))
296+ case protoref .EnumKind :
297+ enumDesc := fd .Enum ()
298+ enumValueDesc := enumDesc .Values ().ByNumber (data .Enum ())
299+ if enumValueDesc == nil {
300+ array .SetArrayItem (index , arena .NewNull ())
301+ return
302+ }
303+
304+ graphqlValue , ok := d .mapping .ResolveEnumValue (string (enumDesc .Name ()), string (enumValueDesc .Name ()))
305+ if ! ok {
306+ array .SetArrayItem (index , arena .NewNull ())
307+ return
308+ }
309+
310+ array .SetArrayItem (index , arena .NewString (graphqlValue ))
311+ }
312+ }
313+
239314func writeErrorBytes (err error ) []byte {
240315 a := astjson.Arena {}
241316 errorRoot := a .NewObject ()
@@ -244,6 +319,15 @@ func writeErrorBytes(err error) []byte {
244319
245320 errorItem := a .NewObject ()
246321 errorItem .Set ("message" , a .NewString (err .Error ()))
322+
323+ extensions := a .NewObject ()
324+ if st , ok := status .FromError (err ); ok {
325+ extensions .Set ("code" , a .NewString (st .Code ().String ()))
326+ } else {
327+ extensions .Set ("code" , a .NewString (codes .Internal .String ()))
328+ }
329+
330+ errorItem .Set ("extensions" , extensions )
247331 errorArray .SetArrayItem (0 , errorItem )
248332
249333 return errorRoot .MarshalTo (nil )
0 commit comments