diff --git a/v2/pkg/engine/datasource/grpc_datasource/entity.go b/v2/pkg/engine/datasource/grpc_datasource/entity.go new file mode 100644 index 000000000..850bb0606 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/entity.go @@ -0,0 +1,74 @@ +package grpcdatasource + +import ( + "errors" + "fmt" + + "github.com/tidwall/gjson" + + "github.com/wundergraph/astjson" +) + +// entityIndexMap maps positions in the typed gRPC response back to positions +// in the original representations array. The slice index is the response +// position; the value is the representation index. It is built per call by +// recording the position of every representation whose __typename matches +// the requested entity type. +type entityIndexMap []int + +// newEntityIndexMap builds the index map for a single entity call by collecting +// the positions of representations whose __typename matches the requested type. +// A single pass over representations populates the slice. +func newEntityIndexMap(requestedEntityType string, representations []gjson.Result) entityIndexMap { + indexMap := make(entityIndexMap, 0, len(representations)) + for i, representation := range representations { + if representation.Get(typenameFieldName).String() == requestedEntityType { + indexMap = append(indexMap, i) + } + } + return indexMap +} + +// getRepresentations gets the representations from the variables. +// If no representations are found, it returns nil. +func getRepresentations(variables gjson.Result) []gjson.Result { + r := variables.Get("representations") + if !r.Exists() { + return nil + } + + return r.Array() +} + +// validateEntityResponse verifies that the number of entities returned by the +// subgraph matches the number of representations of the requested type. +// Callers should subsequently build an entityIndexMap via newEntityIndexMap to +// merge the response — mergeEntities relies on the invariant that +// len(response entities) == len(indexMap), which this function establishes. +func validateEntityResponse(data *astjson.Value, requestedEntityType string, representations []gjson.Result) error { + if data == nil { + return errors.New("validateEntityResponse: subgraph response data is nil") + } + + if requestedEntityType == "" { + return errors.New("validateEntityResponse: requested entity type is empty; the entity RPC plan is missing a RequestedEntityType") + } + + if len(representations) == 0 { + return errors.New("validateEntityResponse: no entity representations provided in the request variables") + } + + expected := 0 + for _, representation := range representations { + if representation.Get(typenameFieldName).String() == requestedEntityType { + expected++ + } + } + + entities := data.Get(entityPath).GetArray() + if len(entities) != expected { + return fmt.Errorf("entity type %s received %d entities in the subgraph response, but %d are expected", requestedEntityType, len(entities), expected) + } + + return nil +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/entity_test.go b/v2/pkg/engine/datasource/grpc_datasource/entity_test.go new file mode 100644 index 000000000..d9d7e10e3 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/entity_test.go @@ -0,0 +1,186 @@ +package grpcdatasource + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tidwall/gjson" + + "github.com/wundergraph/astjson" +) + +func TestNewEntityIndexMap(t *testing.T) { + t.Run("returns empty map when no representations match", func(t *testing.T) { + reps := getRepresentations(gjson.Parse(`{"representations":[ + {"__typename":"Storage","id":"1"} + ]}`)) + idx := newEntityIndexMap("Product", reps) + assert.Equal(t, entityIndexMap{}, idx) + }) + + t.Run("returns empty map when representations are nil", func(t *testing.T) { + idx := newEntityIndexMap("Product", nil) + assert.Equal(t, entityIndexMap{}, idx) + }) + + t.Run("ordered representations [Product, Product, Storage, Storage]", func(t *testing.T) { + reps := getRepresentations(gjson.Parse(`{"representations":[ + {"__typename":"Product","id":"1"}, + {"__typename":"Product","id":"2"}, + {"__typename":"Storage","id":"3"}, + {"__typename":"Storage","id":"4"} + ]}`)) + + productIdx := newEntityIndexMap("Product", reps) + assert.Equal(t, entityIndexMap{0, 1}, productIdx) + + storageIdx := newEntityIndexMap("Storage", reps) + assert.Equal(t, entityIndexMap{2, 3}, storageIdx) + }) + + t.Run("unordered representations [Product, Storage, Product, Storage]", func(t *testing.T) { + reps := getRepresentations(gjson.Parse(`{"representations":[ + {"__typename":"Product","id":"1"}, + {"__typename":"Storage","id":"2"}, + {"__typename":"Product","id":"3"}, + {"__typename":"Storage","id":"4"} + ]}`)) + + productIdx := newEntityIndexMap("Product", reps) + assert.Equal(t, entityIndexMap{0, 2}, productIdx) + + storageIdx := newEntityIndexMap("Storage", reps) + assert.Equal(t, entityIndexMap{1, 3}, storageIdx) + }) + + t.Run("interleaved representations across three types", func(t *testing.T) { + reps := getRepresentations(gjson.Parse(`{"representations":[ + {"__typename":"Product","id":"1"}, + {"__typename":"Storage","id":"2"}, + {"__typename":"Warehouse","id":"3"}, + {"__typename":"Product","id":"4"}, + {"__typename":"Warehouse","id":"5"}, + {"__typename":"Storage","id":"6"} + ]}`)) + + assert.Equal(t, entityIndexMap{0, 3}, newEntityIndexMap("Product", reps)) + assert.Equal(t, entityIndexMap{1, 5}, newEntityIndexMap("Storage", reps)) + assert.Equal(t, entityIndexMap{2, 4}, newEntityIndexMap("Warehouse", reps)) + }) + + t.Run("single matching representation", func(t *testing.T) { + reps := getRepresentations(gjson.Parse(`{"representations":[ + {"__typename":"Storage","id":"1"}, + {"__typename":"Product","id":"2"}, + {"__typename":"Storage","id":"3"} + ]}`)) + + assert.Equal(t, entityIndexMap{1}, newEntityIndexMap("Product", reps)) + }) + + t.Run("preserves original positions for fully matching list", func(t *testing.T) { + reps := getRepresentations(gjson.Parse(`{"representations":[ + {"__typename":"Product","id":"1"}, + {"__typename":"Product","id":"2"}, + {"__typename":"Product","id":"3"} + ]}`)) + + assert.Equal(t, entityIndexMap{0, 1, 2}, newEntityIndexMap("Product", reps)) + }) + + t.Run("interface entity matches by typename string only", func(t *testing.T) { + // Interface-entity representations carry the interface name as __typename + // (e.g. "Resource"). The index map cares only about the typename string, + // not whether it refers to an interface or a concrete type. + reps := getRepresentations(gjson.Parse(`{"representations":[ + {"__typename":"Resource","id":"1"}, + {"__typename":"Product","id":"2"}, + {"__typename":"Resource","id":"3"}, + {"__typename":"Storage","id":"4"}, + {"__typename":"Resource","id":"5"} + ]}`)) + + assert.Equal(t, entityIndexMap{0, 2, 4}, newEntityIndexMap("Resource", reps)) + // Concrete types in the same list are independent. + assert.Equal(t, entityIndexMap{1}, newEntityIndexMap("Product", reps)) + assert.Equal(t, entityIndexMap{3}, newEntityIndexMap("Storage", reps)) + }) +} + +func TestGetRepresentations(t *testing.T) { + t.Run("returns nil when representations key missing", func(t *testing.T) { + vars := gjson.Parse(`{"other":"value"}`) + assert.Nil(t, getRepresentations(vars)) + }) + + t.Run("returns empty slice when representations is empty array", func(t *testing.T) { + vars := gjson.Parse(`{"representations":[]}`) + reps := getRepresentations(vars) + assert.NotNil(t, reps) + assert.Empty(t, reps) + }) + + t.Run("returns representations when present", func(t *testing.T) { + vars := gjson.Parse(`{"representations":[{"__typename":"Product","id":"1"},{"__typename":"Storage","id":"2"}]}`) + reps := getRepresentations(vars) + assert.Len(t, reps, 2) + assert.Equal(t, "Product", reps[0].Get("__typename").String()) + assert.Equal(t, "Storage", reps[1].Get("__typename").String()) + }) +} +func TestValidateEntityResponse(t *testing.T) { + reps := getRepresentations(gjson.Parse(`{"representations":[ + {"__typename":"Product","id":"1"}, + {"__typename":"Product","id":"2"} + ]}`)) + + t.Run("returns error when data is nil", func(t *testing.T) { + err := validateEntityResponse(nil, "Product", reps) + assert.ErrorContains(t, err, "validateEntityResponse: subgraph response data is nil") + }) + + t.Run("returns error when requested entity type is empty", func(t *testing.T) { + data := astjson.MustParse(`{"_entities":[]}`) + err := validateEntityResponse(data, "", reps) + assert.ErrorContains(t, err, "validateEntityResponse: requested entity type is empty") + }) + + t.Run("returns error when representations are empty", func(t *testing.T) { + data := astjson.MustParse(`{"_entities":[]}`) + err := validateEntityResponse(data, "Product", nil) + assert.ErrorContains(t, err, "validateEntityResponse: no entity representations provided") + }) + + t.Run("returns error when entity count mismatches representation count", func(t *testing.T) { + data := astjson.MustParse(`{"_entities":[{"__typename":"Product","id":"1"}]}`) + err := validateEntityResponse(data, "Product", reps) + assert.ErrorContains(t, err, "entity type Product received 1 entities in the subgraph response, but 2 are expected") + }) + + t.Run("returns nil when entity count matches representation count", func(t *testing.T) { + data := astjson.MustParse(`{"_entities":[{"__typename":"Product","id":"1"},{"__typename":"Product","id":"2"}]}`) + assert.NoError(t, validateEntityResponse(data, "Product", reps)) + }) + + t.Run("counts only representations of the requested type", func(t *testing.T) { + mixedReps := getRepresentations(gjson.Parse(`{"representations":[ + {"__typename":"Product","id":"1"}, + {"__typename":"Storage","id":"2"}, + {"__typename":"Product","id":"3"} + ]}`)) + data := astjson.MustParse(`{"_entities":[{"__typename":"Product","id":"1"},{"__typename":"Product","id":"3"}]}`) + assert.NoError(t, validateEntityResponse(data, "Product", mixedReps)) + }) + + t.Run("returns error when _entities key is missing", func(t *testing.T) { + data := astjson.MustParse(`{}`) + err := validateEntityResponse(data, "Product", reps) + assert.ErrorContains(t, err, "entity type Product received 0 entities in the subgraph response, but 2 are expected") + }) + + t.Run("returns error when _entities path is not an array", func(t *testing.T) { + data := astjson.MustParse(`{"_entities":"not an array"}`) + err := validateEntityResponse(data, "Product", reps) + assert.Error(t, err) + }) +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go index a66184f5e..ed3e9d305 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go @@ -110,6 +110,9 @@ type RPCCall struct { Response RPCMessage // ResponsePath is the path to the response in the JSON response ResponsePath ast.Path + // RequestedEntityType is the type of the entity that is being requested + // Empty if the call is not an entity lookup. + RequestedEntityType string } // RPCMessage represents a gRPC message structure for requests and responses. diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go index 0117c80e3..a56882751 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go @@ -52,9 +52,10 @@ func TestExecutionPlan_Federation_EntityLookup(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupProductById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupProductById", + Kind: CallKindEntity, + RequestedEntityType: "Product", // Define the structure of the request message Request: RPCMessage{ Name: "LookupProductByIdRequest", @@ -137,9 +138,10 @@ func TestExecutionPlan_Federation_EntityLookup(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupProductById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupProductById", + Kind: CallKindEntity, + RequestedEntityType: "Product", Request: RPCMessage{ Name: "LookupProductByIdRequest", Fields: []RPCField{ @@ -202,10 +204,11 @@ func TestExecutionPlan_Federation_EntityLookup(t *testing.T) { }, }, { - ID: 1, - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ID: 1, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -348,9 +351,10 @@ func TestExecutionPlan_Federation_EntityKeys(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupUserById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupUserById", + Kind: CallKindEntity, + RequestedEntityType: "User", // Define the structure of the request message Request: RPCMessage{ Name: "LookupUserByIdRequest", @@ -458,9 +462,10 @@ func TestExecutionPlan_Federation_EntityKeys(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupUserByIdAndAddress", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupUserByIdAndAddress", + Kind: CallKindEntity, + RequestedEntityType: "User", // Define the structure of the request message Request: RPCMessage{ Name: "LookupUserByIdAndAddressRequest", @@ -572,9 +577,10 @@ func TestExecutionPlan_Federation_EntityKeys(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupUserByIdAndName", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupUserByIdAndName", + Kind: CallKindEntity, + RequestedEntityType: "User", Request: RPCMessage{ Name: "LookupUserByIdAndNameRequest", Fields: []RPCField{ @@ -674,9 +680,10 @@ func TestExecutionPlan_Federation_EntityKeys(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupUserByIdAndName", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupUserByIdAndName", + Kind: CallKindEntity, + RequestedEntityType: "User", Request: RPCMessage{ Name: "LookupUserByIdAndNameRequest", Fields: []RPCField{ @@ -782,9 +789,10 @@ func TestExecutionPlan_Federation_EntityKeys(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupUserByIdAndNameAndAddress", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupUserByIdAndNameAndAddress", + Kind: CallKindEntity, + RequestedEntityType: "User", Request: RPCMessage{ Name: "LookupUserByIdAndNameAndAddressRequest", Fields: []RPCField{ @@ -906,9 +914,10 @@ func TestExecutionPlan_Federation_EntityKeys(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupUserById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupUserById", + Kind: CallKindEntity, + RequestedEntityType: "User", // Define the structure of the request message Request: RPCMessage{ Name: "LookupUserByIdRequest", @@ -1085,9 +1094,10 @@ func TestEntityLookupWithNestedInlineFragments(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupUserById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupUserById", + Kind: CallKindEntity, + RequestedEntityType: "User", Request: RPCMessage{ Name: "LookupUserByIdRequest", Fields: []RPCField{ @@ -1210,9 +1220,10 @@ func TestEntityLookupWithNestedInlineFragments(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupUserById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupUserById", + Kind: CallKindEntity, + RequestedEntityType: "User", Request: RPCMessage{ Name: "LookupUserByIdRequest", Fields: []RPCField{ @@ -1314,9 +1325,10 @@ func TestEntityLookupWithNestedInlineFragments(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupUserById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupUserById", + Kind: CallKindEntity, + RequestedEntityType: "User", Request: RPCMessage{ Name: "LookupUserByIdRequest", Fields: []RPCField{ @@ -1423,9 +1435,10 @@ func TestEntityLookupWithNestedInlineFragments(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupUserById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupUserById", + Kind: CallKindEntity, + RequestedEntityType: "User", Request: RPCMessage{ Name: "LookupUserByIdRequest", Fields: []RPCField{ @@ -1687,9 +1700,10 @@ func TestEntityLookupWithFieldResolvers_ComplexResolverInNestedMessage(t *testin expectedPlan := &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupProductById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupProductById", + Kind: CallKindEntity, + RequestedEntityType: "Product", Request: RPCMessage{ Name: "LookupProductByIdRequest", Fields: []RPCField{ diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go index ae4a63217..3b7ea647d 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go @@ -4286,9 +4286,10 @@ func TestExecutionPlan_FederationFieldResolvers(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupProductById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupProductById", + Kind: CallKindEntity, + RequestedEntityType: "Product", Request: RPCMessage{ Name: "LookupProductByIdRequest", Fields: []RPCField{ @@ -4464,9 +4465,10 @@ func TestExecutionPlan_FederationFieldResolvers(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -4528,10 +4530,11 @@ func TestExecutionPlan_FederationFieldResolvers(t *testing.T) { }, }, { - ID: 1, - ServiceName: "Products", - MethodName: "LookupProductById", - Kind: CallKindEntity, + ID: 1, + ServiceName: "Products", + MethodName: "LookupProductById", + Kind: CallKindEntity, + RequestedEntityType: "Product", Request: RPCMessage{ Name: "LookupProductByIdRequest", Fields: []RPCField{ @@ -4742,9 +4745,10 @@ func TestExecutionPlan_FederationFieldResolvers_WithCompositeTypes(t *testing.T) expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupProductById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupProductById", + Kind: CallKindEntity, + RequestedEntityType: "Product", Request: RPCMessage{ Name: "LookupProductByIdRequest", Fields: []RPCField{ @@ -4920,9 +4924,10 @@ func TestExecutionPlan_FederationFieldResolvers_WithCompositeTypes(t *testing.T) expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupProductById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupProductById", + Kind: CallKindEntity, + RequestedEntityType: "Product", Request: RPCMessage{ Name: "LookupProductByIdRequest", Fields: []RPCField{ @@ -5103,9 +5108,10 @@ func TestExecutionPlan_FederationFieldResolvers_WithCompositeTypes(t *testing.T) expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupProductById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupProductById", + Kind: CallKindEntity, + RequestedEntityType: "Product", Request: RPCMessage{ Name: "LookupProductByIdRequest", Fields: []RPCField{ diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_requires_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_requires_test.go index e5c140609..d70a2c5e9 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_requires_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_requires_test.go @@ -37,9 +37,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupWarehouseById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupWarehouseById", + Kind: CallKindEntity, + RequestedEntityType: "Warehouse", Request: RPCMessage{ Name: "LookupWarehouseByIdRequest", Fields: []RPCField{ @@ -204,9 +205,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -352,9 +354,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -501,9 +504,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -658,9 +662,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -840,9 +845,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -1013,9 +1019,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -1162,9 +1169,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -1398,9 +1406,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -1546,9 +1555,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -1709,9 +1719,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -1878,9 +1889,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -2042,9 +2054,10 @@ func TestExecutionPlan_FederationRequires(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -2224,9 +2237,10 @@ func TestExecutionPlan_FederationRequires_AbstractTypes(t *testing.T) { // storageEntityLookupCall returns the common entity lookup call shared by all tests storageEntityLookupCall := func() RPCCall { return RPCCall{ - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -3165,9 +3179,10 @@ func TestExecutionPlan_FederationRequires_WithFieldResolvers(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -3399,9 +3414,10 @@ func TestExecutionPlan_FederationRequires_WithFieldResolvers(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -3655,9 +3671,10 @@ func TestExecutionPlan_FederationRequires_WithFieldResolvers(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -3908,9 +3925,10 @@ func TestExecutionPlan_FederationRequires_WithFieldResolvers(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ @@ -4230,9 +4248,10 @@ func TestExecutionPlan_FederationRequires_WithFieldResolvers(t *testing.T) { expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { - ServiceName: "Products", - MethodName: "LookupStorageById", - Kind: CallKindEntity, + ServiceName: "Products", + MethodName: "LookupStorageById", + Kind: CallKindEntity, + RequestedEntityType: "Storage", Request: RPCMessage{ Name: "LookupStorageByIdRequest", Fields: []RPCField{ diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go index e95b1693f..7656f9759 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go @@ -132,9 +132,10 @@ func (r *rpcPlanVisitorFederation) EnterInlineFragment(ref int) { } r.currentCall = &RPCCall{ - ID: r.callIndex, - ServiceName: r.planCtx.resolveServiceName(r.subgraphName), - Kind: CallKindEntity, + ID: r.callIndex, + ServiceName: r.planCtx.resolveServiceName(r.subgraphName), + Kind: CallKindEntity, + RequestedEntityType: fragmentName, } r.callIndex++ @@ -508,9 +509,8 @@ func (r *rpcPlanVisitorFederation) resolveEntityInformation(inlineFragmentRef in return errors.New("definition node not found for inline fragment: " + fragmentName) } - // Only process object type definitions - // TODO: handle interfaces - if node.Kind != ast.NodeKindObjectTypeDefinition { + // Only process object type definitions and interface type definitions + if node.Kind != ast.NodeKindObjectTypeDefinition && node.Kind != ast.NodeKindInterfaceTypeDefinition { return nil } @@ -568,6 +568,16 @@ func (r *rpcPlanVisitorFederation) scaffoldEntityLookup(typeName string, ecd ent }, } + // Check if the entity type is an interface and set oneof type and member types. + if node, found := r.definition.NodeByNameStr(typeName); found { + if node.Kind == ast.NodeKindInterfaceTypeDefinition { + entityMessage.OneOfType = OneOfTypeInterface + if memberTypes, ok := r.definition.InterfaceTypeDefinitionImplementedByObjectWithNames(node.Ref); ok { + entityMessage.MemberTypes = memberTypes + } + } + } + // The proto response message has a field `result` which is a list of entities. // As this is a special case we directly map it to _entities. r.planInfo.currentResponseMessage.Fields = []RPCField{ diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 4bc774ea1..527b3e244 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -30,9 +30,10 @@ import ( ) type resultData struct { - kind CallKind - response *astjson.Value - responsePath ast.Path + kind CallKind + response *astjson.Value + responsePath ast.Path + entityIndexMap entityIndexMap } // Verify DataSource implements the resolve.DataSource interface @@ -47,6 +48,7 @@ type DataSource struct { rc *RPCCompiler mapping *GRPCMapping federationConfigs plan.FederationFieldConfigurations + definition *ast.Document disabled bool pool *arena.Pool @@ -82,6 +84,7 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D cc: client, rc: config.Compiler, mapping: config.Mapping, + definition: config.Definition, federationConfigs: config.FederationConfigs, disabled: config.Disabled, pool: arena.NewArenaPool(), @@ -107,6 +110,7 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte item := d.acquirePoolItem(input, 0) poolItems = append(poolItems, item) + builder := newJSONBuilder(item.Arena, d.mapping, variables) if d.disabled { @@ -130,6 +134,7 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte root := astjson.ObjectValue(nil) + representations := getRepresentations(variables) if err := graph.TopologicalSortResolve(func(nodes []FetchItem) error { serviceCalls, err := d.rc.CompileFetches(graph, nodes, variables) if err != nil { @@ -143,10 +148,10 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte for index, serviceCall := range serviceCalls { item := d.acquirePoolItem(input, index) poolItems = append(poolItems, item) + builder := newJSONBuilder(item.Arena, d.mapping, variables) errGrp.Go(func() error { // Invoke the gRPC method - this will populate serviceCall.Output - err := d.cc.Invoke(errGrpCtx, serviceCall.MethodFullName(), serviceCall.Input, serviceCall.Output) if err != nil { return err @@ -157,19 +162,22 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte return err } + results[index] = resultData{ + kind: serviceCall.RPC.Kind, + response: response, + responsePath: serviceCall.RPC.ResponsePath, + } + // In case of a federated response, we need to ensure that the response is valid. - // The number of entities per type must match the number of lookup keys in the variablese + // The number of entities per type must match the number of lookup keys in the variables. + // On success we build the index map used by mergeEntities to place each response + // entity at the correct position in the final _entities array. if serviceCall.RPC.Kind == CallKindEntity { - err = builder.validateFederatedResponse(response) - if err != nil { + if err := validateEntityResponse(response, serviceCall.RPC.RequestedEntityType, representations); err != nil { return err } - } - results[index] = resultData{ - kind: serviceCall.RPC.Kind, - response: response, - responsePath: serviceCall.RPC.ResponsePath, + results[index].entityIndexMap = newEntityIndexMap(serviceCall.RPC.RequestedEntityType, representations) } return nil @@ -185,7 +193,7 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte case CallKindResolve, CallKindRequired: err = builder.mergeWithPath(root, result.response, result.responsePath) default: - root, err = builder.mergeValues(root, result.response) + root, err = builder.mergeValues(root, result) } if err != nil { return err diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index d912c2fe2..112ce26d6 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -611,8 +611,7 @@ func TestMarshalResponseJSON(t *testing.T) { responseMessageDesc := responseMsg.Desc responseMessage := dynamicpb.NewMessage(responseMessageDesc) responseMessage.Mutable(responseMessageDesc.Fields().ByName("result")).List().Append(protoref.ValueOfMessage(productMessage)) - - jsonBuilder := newJSONBuilder(nil, nil, gjson.Result{}) + jsonBuilder := newJSONBuilder(nil, testMapping(), gjson.Result{}) responseJSON, err := jsonBuilder.marshalResponseJSON(&response, responseMessage) require.NoError(t, err) require.Equal(t, `{"_entities":[{"__typename":"Product","id":"123","name_different":"test","price_different":123.45}]}`, responseJSON.String()) diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go index 3eeff2a04..556371d51 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -24,80 +24,6 @@ const ( resolveResponsePath = "result" // Path for resolve response ) -// entityIndex represents the mapping between representation order and result order -// for GraphQL federation entities. This is crucial for maintaining correct entity -// order when multiple subgraphs return entities in different orders. -type entityIndex struct { - representationIndex int // Index in the original representation array - resultIndex int // Index where this entity should appear in the final result -} - -// indexMap maps GraphQL type names to their corresponding entity indices -// This allows proper ordering of federated entities by type -type indexMap map[string][]entityIndex - -// getResultIndex returns the correct result index for an entity based on its type -// and representation index. This ensures federated entities maintain proper ordering -// across multiple subgraph responses. -func (i indexMap) getResultIndex(val *astjson.Value, representationIndex int) int { - if i == nil { - return representationIndex - } - - if val == nil { - return representationIndex - } - - // Extract the __typename field to determine entity type - typeName := val.Get("__typename").GetStringBytes() - - // Find the correct result index for this type and representation index - for _, entityIndex := range i[string(typeName)] { - if entityIndex.representationIndex == representationIndex { - return entityIndex.resultIndex - } - } - - // Fallback to representation index if no mapping found - return representationIndex -} - -// createRepresentationIndexMap builds an index mapping for GraphQL federation entities -// from the variables containing entity representations. This map is used to ensure -// that entities are returned in the correct order when merging responses from multiple -// subgraphs, which is critical for GraphQL federation correctness. -func createRepresentationIndexMap(variables gjson.Result) indexMap { - var representations []gjson.Result - r := variables.Get("representations") - if !r.Exists() { - return nil - } - - representations = r.Array() - im := make(indexMap) - indexSet := make(map[string]int) // Track count per type name - - // Build mapping for each representation - for i, representation := range representations { - typeName := representation.Get("__typename").String() - - // Initialize counter for new type names - if _, ok := indexSet[typeName]; !ok { - indexSet[typeName] = -1 - } - - // Increment index for this type - indexSet[typeName]++ - - // Create mapping entry for this entity - im[typeName] = append(im[typeName], entityIndex{ - representationIndex: indexSet[typeName], // Position within entities of this type - resultIndex: i, // Position in the overall result array - }) - } - return im -} - // jsonBuilder is the core component responsible for converting gRPC protobuf responses // into GraphQL-compatible JSON format. It handles complex scenarios including: // - GraphQL federation entity merging and ordering @@ -107,7 +33,6 @@ func createRepresentationIndexMap(variables gjson.Result) indexMap { type jsonBuilder struct { mapping *GRPCMapping // Mapping configuration for GraphQL to gRPC translation variables gjson.Result // GraphQL variables containing entity representations - indexMap indexMap // Entity index mapping for federation ordering jsonArena arena.Arena } @@ -118,72 +43,24 @@ func newJSONBuilder(a arena.Arena, mapping *GRPCMapping, variables gjson.Result) return &jsonBuilder{ mapping: mapping, variables: variables, - indexMap: createRepresentationIndexMap(variables), jsonArena: a, } } -// validateFederatedResponse validates that the federated response is valid -// by checking that the number of entities per type is correct. -// For non-federated responses, this function is a no-op. -func (j *jsonBuilder) validateFederatedResponse(response *astjson.Value) error { - if j.indexMap == nil { - return nil - } - - // Get the entities array from the response - // If we have an index map, we expect it to be a federated response - entities, err := response.Get(entityPath).Array() - if err != nil { - return err - } - - // Count the number of entities per type - entitiyCountPerType := make(map[string]int) - for _, entity := range entities { - entityType := entity.Get("__typename").GetStringBytes() - entitiyCountPerType[string(entityType)]++ - } - - // Check that the number of entities per type is correct and exists in the index map. - for typeName, count := range entitiyCountPerType { - em, found := j.indexMap[typeName] - if !found { - return fmt.Errorf("entity type %s received in the subgraph response, but was not expected", typeName) - } - - if len(em) != count { - return fmt.Errorf("entity type %s received %d entities in the subgraph response, but %d are expected", typeName, count, len(em)) - } - } - return nil -} - // mergeValues combines two JSON values while preserving proper federation entity ordering. // This is a critical function for GraphQL federation where multiple subgraphs may // return entities that need to be merged in the correct order. -func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { - if len(j.indexMap) == 0 { +func (j *jsonBuilder) mergeValues(left *astjson.Value, right resultData) (*astjson.Value, error) { + if right.kind != CallKindEntity { // No federation index map available - use simple merge // This path is taken for non-federated queries - root, _, err := astjson.MergeValues(j.jsonArena, left, right) + root, _, err := astjson.MergeValues(j.jsonArena, left, right.response) if err != nil { return nil, err } return root, nil } - // Federation entities present - must preserve representation order - leftObject, err := left.Object() - if err != nil { - return nil, err - } - - // If left side is empty, just return right side - if leftObject.Len() == 0 { - return right, nil - } - // Perform federation-aware entity merging return j.mergeEntities(left, right) } @@ -191,35 +68,34 @@ func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*a // mergeEntities performs federation-aware merging of entity arrays from multiple subgraph responses. // This function ensures that entities are placed in the correct positions in the final response // array based on their original representation order, which is critical for GraphQL federation. -func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { - - // Create the response structure with _entities array - entities := astjson.ObjectValue(j.jsonArena) - entities.Set(j.jsonArena, entityPath, astjson.ArrayValue(j.jsonArena)) - arr := entities.Get(entityPath) - - // Extract entity arrays from both responses - leftRepresentations, err := left.Get(entityPath).Array() - if err != nil { - return nil, err - } +// +// entityIndexMap is indexed directly without bounds checks: if the lengths +// did not match we would have already aborted in validateEntityResponse. +// +// On the first call, left is the empty root object and we allocate a fresh +// _entities array. On subsequent calls, left is the result of a previous +// mergeEntities call and already holds the _entities array, so we mutate it +// in place rather than copying every accumulated entity into a new array. +func (j *jsonBuilder) mergeEntities(left *astjson.Value, rightResult resultData) (*astjson.Value, error) { + right := rightResult.response + rightEntities := right.Get(entityPath).GetArray() - rightRepresentations, err := right.Get(entityPath).Array() - if err != nil { - return nil, err + if left == nil { + left = astjson.ObjectValue(j.jsonArena) } - // Merge left entities using index mapping to preserve order - for index, lr := range leftRepresentations { - arr.SetArrayItem(j.jsonArena, j.indexMap.getResultIndex(lr, index), lr) + arr := left.Get(entityPath) + if arr == nil || arr.Type() != astjson.TypeArray { + left.Set(j.jsonArena, entityPath, astjson.ArrayValue(j.jsonArena)) + arr = left.Get(entityPath) } - // Merge right entities using index mapping to preserve order - for index, rr := range rightRepresentations { - arr.SetArrayItem(j.jsonArena, j.indexMap.getResultIndex(rr, index), rr) + // Place right's entities at their global positions in the merged array. + for index, rr := range rightEntities { + arr.SetArrayItem(j.jsonArena, rightResult.entityIndexMap[index], rr) } - return entities, nil + return left, nil } // mergeWithPath merges a JSON value with a resolved value by its path. @@ -368,9 +244,9 @@ func (j *jsonBuilder) marshalResponseJSON(message *RPCMessage, data protoref.Mes } // Type-specific static value - match against member types - for _, memberTypes := range message.MemberTypes { - if memberTypes == string(data.Type().Descriptor().Name()) { - root.Set(j.jsonArena, field.AliasOrPath(), astjson.StringValue(j.jsonArena, memberTypes)) + for _, memberType := range message.MemberTypes { + if memberType == string(data.Type().Descriptor().Name()) { + root.Set(j.jsonArena, field.AliasOrPath(), astjson.StringValue(j.jsonArena, memberType)) break } }