diff --git a/.gitignore b/.gitignore index 960ca0c8d5..17c4571439 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ .DS_Store pkg/parser/testdata/lotto.graphql *node_modules* -*vendor* \ No newline at end of file +*vendor* +.serena \ No newline at end of file diff --git a/examples/federation/go.mod b/examples/federation/go.mod index 882ee7fd46..2c6235bf58 100644 --- a/examples/federation/go.mod +++ b/examples/federation/go.mod @@ -19,6 +19,7 @@ require ( github.com/bufbuild/protocompile v0.14.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/coder/websocket v1.8.14 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.11.0 // indirect @@ -40,6 +41,7 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/r3labs/sse/v2 v2.8.1 // indirect + github.com/rs/xid v1.6.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/sosodev/duration v1.3.1 // indirect diff --git a/examples/federation/go.sum b/examples/federation/go.sum index 4bcbb63885..b7957488be 100644 --- a/examples/federation/go.sum +++ b/examples/federation/go.sum @@ -20,8 +20,8 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY= github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic= github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= -github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -59,8 +59,7 @@ github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/pprof v0.0.0-20230207041349-798e818bf904 h1:4/hN5RUoecvl+RmJRE2YxKWtnnQls6rQjjW5oV7qg2U= github.com/google/pprof v0.0.0-20230207041349-798e818bf904/go.mod h1:uglQLonpP8qtYCYyzA+8c/9qtqgA3qsXGYqCPKARAFg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= @@ -121,6 +120,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sebdah/goldie/v2 v2.7.1 h1:PkBHymaYdtvEkZV7TmyqKxdmn5/Vcj+8TpATWZjnG5E= @@ -154,8 +155,7 @@ github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= +github.com/wundergraph/astjson v1.0.0 h1:rETLJuQkMWWW03HCF6WBttEBOu8gi5vznj5KEUPVV2Q= github.com/wundergraph/astjson v1.0.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= github.com/wundergraph/cosmo/composition-go v0.0.0-20241020204711-78f240a77c99 h1:TGXDYfDhwFLFTuNuCwkuqXT5aXGz47zcurXLfTBS9w4= github.com/wundergraph/cosmo/composition-go v0.0.0-20241020204711-78f240a77c99/go.mod h1:fUuOAUAXUFB/mlSkAaImGeE4A841AKR5dTMWhV4ibxI= diff --git a/execution/engine/config_factory_federation.go b/execution/engine/config_factory_federation.go index fca8b342b1..dab9c708a0 100644 --- a/execution/engine/config_factory_federation.go +++ b/execution/engine/config_factory_federation.go @@ -451,25 +451,14 @@ func (f *FederationEngineConfigFactory) graphqlDataSourceFactory() (plan.Planner func (f *FederationEngineConfigFactory) subscriptionClient( httpClient *http.Client, streamingClient *http.Client, - subscriptionType SubscriptionType, + _ SubscriptionType, subscriptionClientFactory graphql_datasource.GraphQLSubscriptionClientFactory, ) (graphql_datasource.GraphQLSubscriptionClient, error) { - var graphqlSubscriptionClient graphql_datasource.GraphQLSubscriptionClient - switch subscriptionType { - case SubscriptionTypeGraphQLTransportWS: - graphqlSubscriptionClient = subscriptionClientFactory.NewSubscriptionClient( - httpClient, - streamingClient, - f.engineCtx, - ) - default: - // for compatibility reasons we fall back to graphql-ws protocol - graphqlSubscriptionClient = subscriptionClientFactory.NewSubscriptionClient( - httpClient, - streamingClient, - f.engineCtx, - ) - } + graphqlSubscriptionClient := subscriptionClientFactory.NewSubscriptionClient( + f.engineCtx, + graphql_datasource.WithUpgradeClient(httpClient), + graphql_datasource.WithStreamingClient(streamingClient), + ) ok := graphql_datasource.IsDefaultGraphQLSubscriptionClient(graphqlSubscriptionClient) if !ok { diff --git a/execution/engine/engine_config.go b/execution/engine/engine_config.go index 551a9858d6..67a38039b7 100644 --- a/execution/engine/engine_config.go +++ b/execution/engine/engine_config.go @@ -143,7 +143,7 @@ func (d *graphqlDataSourceGenerator) Generate(dsID string, config graphql_dataso return nil, err } - return plan.NewDataSourceConfiguration[graphql_datasource.Configuration]( + return plan.NewDataSourceConfiguration( dsID, factory, &plan.DataSourceMetadata{ @@ -155,27 +155,16 @@ func (d *graphqlDataSourceGenerator) Generate(dsID string, config graphql_dataso } func (d *graphqlDataSourceGenerator) generateSubscriptionClient(httpClient *http.Client, definedOptions *dataSourceGeneratorOptions) (graphql_datasource.GraphQLSubscriptionClient, error) { - var graphqlSubscriptionClient graphql_datasource.GraphQLSubscriptionClient - switch definedOptions.subscriptionType { - case SubscriptionTypeGraphQLTransportWS: - graphqlSubscriptionClient = definedOptions.subscriptionClientFactory.NewSubscriptionClient( - httpClient, - definedOptions.streamingClient, - nil, - ) - default: - // for compatibility reasons we fall back to graphql-ws protocol - graphqlSubscriptionClient = definedOptions.subscriptionClientFactory.NewSubscriptionClient( - httpClient, - definedOptions.streamingClient, - nil, - ) - } + graphqlSubscriptionClient := definedOptions.subscriptionClientFactory.NewSubscriptionClient(d.engineCtx, + graphql_datasource.WithUpgradeClient(httpClient), + graphql_datasource.WithStreamingClient(definedOptions.streamingClient), + ) ok := graphql_datasource.IsDefaultGraphQLSubscriptionClient(graphqlSubscriptionClient) if !ok { return nil, errors.New("invalid subscriptionClient was instantiated") } + return graphqlSubscriptionClient, nil } diff --git a/execution/engine/engine_config_test.go b/execution/engine/engine_config_test.go index db6427d70b..334202f939 100644 --- a/execution/engine/engine_config_test.go +++ b/execution/engine/engine_config_test.go @@ -280,11 +280,11 @@ func TestGraphqlFieldConfigurationsGenerator_Generate(t *testing.T) { } -var mockSubscriptionClient = graphqlDataSource.NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, context.Background()) +var mockSubscriptionClient = graphqlDataSource.NewGraphQLSubscriptionClient(context.Background()) type MockSubscriptionClientFactory struct{} -func (m *MockSubscriptionClientFactory) NewSubscriptionClient(httpClient, streamingClient *http.Client, engineCtx context.Context, options ...graphqlDataSource.Options) graphqlDataSource.GraphQLSubscriptionClient { +func (m *MockSubscriptionClientFactory) NewSubscriptionClient(engineCtx context.Context, options ...graphqlDataSource.SubscriptionClientOption) graphqlDataSource.GraphQLSubscriptionClient { return mockSubscriptionClient } diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index 1459cbdeaf..192a7cae77 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -54,7 +54,8 @@ func mustConfiguration(t *testing.T, input graphql_datasource.ConfigurationInput func mustFactory(t testing.TB, httpClient *http.Client) plan.PlannerFactory[graphql_datasource.Configuration] { t.Helper() - factory, err := graphql_datasource.NewFactory(context.Background(), httpClient, graphql_datasource.NewGraphQLSubscriptionClient(httpClient, httpClient, context.Background())) + factory, err := graphql_datasource.NewFactory(context.Background(), httpClient, graphql_datasource.NewGraphQLSubscriptionClient(context.Background(), + graphql_datasource.WithUpgradeClient(httpClient), graphql_datasource.WithStreamingClient(httpClient))) require.NoError(t, err) return factory @@ -6079,10 +6080,9 @@ func newFederationEngineStaticConfig(ctx context.Context, setup *federationtesti return } - subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient( - httpclient.DefaultNetHttpClient, - httpclient.DefaultNetHttpClient, - ctx, + subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient(ctx, + graphql_datasource.WithUpgradeClient(httpclient.DefaultNetHttpClient), + graphql_datasource.WithStreamingClient(httpclient.DefaultNetHttpClient), ) graphqlFactory, err := graphql_datasource.NewFactory(ctx, httpclient.DefaultNetHttpClient, subscriptionClient) diff --git a/execution/go.mod b/execution/go.mod index 8fb7c1fcb4..31ccb51c24 100644 --- a/execution/go.mod +++ b/execution/go.mod @@ -14,12 +14,12 @@ require ( github.com/sebdah/goldie/v2 v2.7.1 github.com/stretchr/testify v1.11.1 github.com/vektah/gqlparser/v2 v2.5.30 - github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 + github.com/wundergraph/astjson v1.0.0 github.com/wundergraph/cosmo/composition-go v0.0.0-20241020204711-78f240a77c99 github.com/wundergraph/cosmo/router v0.0.0-20251013094319-c611abf26b17 github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.231 go.uber.org/atomic v1.11.0 - google.golang.org/grpc v1.68.1 + google.golang.org/grpc v1.71.0 google.golang.org/protobuf v1.36.9 ) @@ -29,6 +29,7 @@ require ( github.com/bufbuild/protocompile v0.14.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/coder/websocket v1.8.14 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.11.0 // indirect @@ -65,6 +66,7 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect github.com/urfave/cli/v2 v2.27.7 // indirect + github.com/wundergraph/go-arena v1.1.0 // indirect github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect diff --git a/execution/go.sum b/execution/go.sum index babde00b17..5d29f427f1 100644 --- a/execution/go.sum +++ b/execution/go.sum @@ -18,8 +18,8 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chzyer/logex v1.2.0/go.mod h1:9+9sk7u7pGNWYMkh0hdiL++6OeibzJccyQU4p4MedaY= github.com/chzyer/readline v1.5.0/go.mod h1:x22KAscuvRqlLoK9CsoYsmxoXZMMFVyOl86cAH8qUic= github.com/chzyer/test v0.0.0-20210722231415-061457976a23/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= -github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -44,6 +44,10 @@ github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= @@ -163,18 +167,32 @@ github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= -github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= +github.com/wundergraph/astjson v1.0.0 h1:rETLJuQkMWWW03HCF6WBttEBOu8gi5vznj5KEUPVV2Q= +github.com/wundergraph/astjson v1.0.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= github.com/wundergraph/cosmo/composition-go v0.0.0-20241020204711-78f240a77c99 h1:TGXDYfDhwFLFTuNuCwkuqXT5aXGz47zcurXLfTBS9w4= github.com/wundergraph/cosmo/composition-go v0.0.0-20241020204711-78f240a77c99/go.mod h1:fUuOAUAXUFB/mlSkAaImGeE4A841AKR5dTMWhV4ibxI= github.com/wundergraph/cosmo/router v0.0.0-20251013094319-c611abf26b17 h1:GjO2E8LTf3U5JiQJCY4MmlRcAjVt7IvAbWFSgEjQdl8= github.com/wundergraph/cosmo/router v0.0.0-20251013094319-c611abf26b17/go.mod h1:7kt64e0LOLMBqOzrfu9PuLRn9cVT9YN1Bb3EennVtws= +github.com/wundergraph/go-arena v1.1.0 h1:9+wSRkJAkA2vbYHp6s8tEGhPViRGQNGXqPHT0QzhdIc= +github.com/wundergraph/go-arena v1.1.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.231 h1:2C8LNFGs8MtI2yPy2/a2WRf9/X2FoMqXlEJkpTjvsTg= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.231/go.mod h1:ErOQH1ki2+SZB8JjpTyGVnoBpg5picIyjvuWQJP4abg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.36.0 h1:UumtzIklRBY6cI/lllNZlALOF5nNIzJVb16APdvgTXg= +go.opentelemetry.io/otel v1.36.0/go.mod h1:/TcFMXYjyRNh8khOAO9ybYkqaDBb/70aVwkNML4pP8E= +go.opentelemetry.io/otel/metric v1.36.0 h1:MoWPKVhQvJ+eeXWHFBOPoBOi20jh6Iq2CcCREuTYufE= +go.opentelemetry.io/otel/metric v1.36.0/go.mod h1:zC7Ks+yeyJt4xig9DEw9kuUFe5C3zLbVjV2PzT6qzbs= +go.opentelemetry.io/otel/sdk v1.36.0 h1:b6SYIuLRs88ztox4EyrvRti80uXIFy+Sqzoh9kFULbs= +go.opentelemetry.io/otel/sdk v1.36.0/go.mod h1:+lC+mTgD+MUWfjJubi2vvXWcVxyr9rmlshZni72pXeY= +go.opentelemetry.io/otel/sdk/metric v1.36.0 h1:r0ntwwGosWGaa0CrSt8cuNuTcccMXERFwHX4dThiPis= +go.opentelemetry.io/otel/sdk/metric v1.36.0/go.mod h1:qTNOhFDfKRwX0yXOqJYegL5WRaW376QbB7P4Pb0qva4= +go.opentelemetry.io/otel/trace v1.36.0 h1:ahxWNuqZjpdiFAyrIoQ4GIiAIhxAunQR6MUoKrsNd4w= +go.opentelemetry.io/otel/trace v1.36.0/go.mod h1:gQ+OnDZzrybY4k4seLzPAWNwVBBVlF2szhehOBB/tGA= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= @@ -258,8 +276,8 @@ gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50= -google.golang.org/grpc v1.68.1 h1:oI5oTa11+ng8r8XMMN7jAOmWfPZWbYpCFaMUTACxkM0= -google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe00waw= +google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= +google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y= diff --git a/execution/graphql/result_writer.go b/execution/graphql/result_writer.go index f97cbf721b..e3905207d3 100644 --- a/execution/graphql/result_writer.go +++ b/execution/graphql/result_writer.go @@ -39,8 +39,9 @@ func (e *EngineResultWriter) Heartbeat() error { return nil } -func (e *EngineResultWriter) Close(_ resolve.SubscriptionCloseKind) { - +func (e *EngineResultWriter) Error(data []byte) { + e.buf.Write(data) + e.Flush() } func (e *EngineResultWriter) SetFlushCallback(flushCb func(data []byte)) { diff --git a/execution/subscription/legacy_handler_test.go b/execution/subscription/legacy_handler_test.go index 1ca4933258..efc42b4354 100644 --- a/execution/subscription/legacy_handler_test.go +++ b/execution/subscription/legacy_handler_test.go @@ -567,10 +567,9 @@ func setupEngineV2(t *testing.T, ctx context.Context, chatServerURL string) (*Ex engineConf := engine.NewConfiguration(chatSchema) - subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient( - httpclient.DefaultNetHttpClient, - httpclient.DefaultNetHttpClient, - ctx, + subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient(ctx, + graphql_datasource.WithUpgradeClient(httpclient.DefaultNetHttpClient), + graphql_datasource.WithStreamingClient(httpclient.DefaultNetHttpClient), ) factory, err := graphql_datasource.NewFactory(ctx, httpclient.DefaultNetHttpClient, subscriptionClient) diff --git a/execution/subscription/websocket/handler_test.go b/execution/subscription/websocket/handler_test.go index acfe4d4e24..21849b271b 100644 --- a/execution/subscription/websocket/handler_test.go +++ b/execution/subscription/websocket/handler_test.go @@ -279,10 +279,9 @@ func setupExecutorPoolV2(t *testing.T, ctx context.Context, chatServerURL string engineConf := engine.NewConfiguration(chatSchema) engineConf.SetWebsocketBeforeStartHook(onBeforeStartHook) - subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient( - httpclient.DefaultNetHttpClient, - httpclient.DefaultNetHttpClient, - ctx, + subscriptionClient := graphql_datasource.NewGraphQLSubscriptionClient(ctx, + graphql_datasource.WithUpgradeClient(httpclient.DefaultNetHttpClient), + graphql_datasource.WithStreamingClient(httpclient.DefaultNetHttpClient), ) factory, err := graphql_datasource.NewFactory(ctx, httpclient.DefaultNetHttpClient, subscriptionClient) diff --git a/go.work.sum b/go.work.sum index 1aecd8d220..f3972893ca 100644 --- a/go.work.sum +++ b/go.work.sum @@ -49,6 +49,7 @@ github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a h1:8d1CEOF1xlde github.com/cloudflare/backoff v0.0.0-20161212185259-647f3cdfc87a/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY= github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3 h1:boJj011Hh+874zpIySeApCX4GeOjPl9qhRF3QuIZq+Q= github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= +github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/cgroups/v3 v3.0.2 h1:f5WFqIVSgo5IZmtTT3qVBo6TzI1ON6sycSBKkymb9L0= github.com/containerd/cgroups/v3 v3.0.2/go.mod h1:JUgITrzdFqp42uI2ryGA+ge0ap/nxzYgkGmIcetmErE= github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8= @@ -56,6 +57,7 @@ github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9 github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9 h1:uDmaGzcdjhF4i/plgjmEsriH11Y0o7RKapEf/LDaM3w= +github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgraph-io/ristretto/v2 v2.1.0 h1:59LjpOJLNDULHh8MC4UaegN52lC4JnO2dITsie/Pa8I= github.com/dgraph-io/ristretto/v2 v2.1.0/go.mod h1:uejeqfYXpUomfse0+lO+13ATz4TypQYLJZzBSAemuB4= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= @@ -81,6 +83,8 @@ github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfU github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= github.com/expr-lang/expr v1.17.6 h1:1h6i8ONk9cexhDmowO/A64VPxHScu7qfSl2k8OlINec= github.com/expr-lang/expr v1.17.6/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/go-chi/chi/v5 v5.2.2 h1:CMwsvRVTbXVytCk1Wd72Zy1LAsAh9GxMmSNWLHCG618= @@ -116,8 +120,7 @@ github.com/golang/glog v1.2.4 h1:CNNw5U8lSiiBk7druxtSHHTsRWcxKoac6kZKm2peBBc= github.com/golang/glog v1.2.4/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= -github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-containerregistry v0.20.3 h1:oNx7IdTI936V8CQRveCjaxOiegWwvM7kqkbXTpyiovI= github.com/google/go-containerregistry v0.20.3/go.mod h1:w00pIgBRDVUDFM6bq+Qx8lwNWK+cxgCuX1vd3PIBDNI= github.com/google/renameio v0.1.0 h1:GOZbcHa3HfsPKPlmyPyN2KEohoMXOhdMbHrvbpl2QaA= @@ -132,6 +135,7 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= @@ -141,6 +145,7 @@ github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47 github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2 h1:rcanfLhLDA8nozr/K289V1zcntHr3V+SHlXwzz1ZI2g= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jensneuse/byte-template v0.0.0-20200214152254-4f3cf06e5c68/go.mod h1:0D5r/VSW6D/o65rKLL9xk7sZxL2+oku2HvFPYeIMFr4= github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= @@ -164,7 +169,10 @@ github.com/mark3labs/mcp-go v0.36.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZ github.com/matryer/moq v0.2.7 h1:RtpiPUM8L7ZSCbSwK+QcZH/E9tgqAkFjKQxsRs25b4w= github.com/matryer/moq v0.5.2 h1:b2bsanSaO6IdraaIvPBzHnqcrkkQmk1/310HdT2nNQs= github.com/matryer/moq v0.5.2/go.mod h1:W/k5PLfou4f+bzke9VPXTbfJljxoeR1tLHigsmbshmU= +github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= @@ -195,6 +203,7 @@ github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFu github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= +github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posthog/posthog-go v1.5.5 h1:2o3j7IrHbTIfxRtj4MPaXKeimuTYg49onNzNBZbwksM= github.com/posthog/posthog-go v1.5.5/go.mod h1:3RqUmSnPuwmeVj/GYrS75wNGqcAKdpODiwc83xZWgdE= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= @@ -213,6 +222,7 @@ github.com/redis/go-redis/v9 v9.4.0 h1:Yzoz33UZw9I/mFhx4MNrB6Fk+XHO1VukNcCa1+lwy github.com/redis/go-redis/v9 v9.4.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/sanity-io/litter v1.5.8/go.mod h1:9gzJgR2i4ZpjZHsKvUXIRQVk7P+yM3e+jAF7bU2UI5U= github.com/santhosh-tekuri/jsonschema/v6 v6.0.1 h1:PKK9DyHxif4LZo+uQSgXNqs0jj5+xZwwfKHgph2lxBw= github.com/santhosh-tekuri/jsonschema/v6 v6.0.1/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU= github.com/shirou/gopsutil/v3 v3.24.3 h1:eoUGJSmdfLzJ3mxIhmOAhgKEKgQkeOwKpz1NbhVnuPE= @@ -227,10 +237,12 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v0.0.0-20161117074351-18a02ba4a312/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/sjson v1.0.4 h1:UcdIRXff12Lpnu3OLtZvnc03g4vH2suXDXhBwBqmzYg= github.com/tidwall/sjson v1.0.4/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= @@ -246,9 +258,12 @@ github.com/twmb/franz-go/pkg/kmsg v1.7.0/go.mod h1:se9Mjdt0Nwzc9lnjJ0HyDtLyBnaBD github.com/vbatts/tar-split v0.12.1 h1:CqKoORW7BUWBe7UL/iqTVvkTBOF8UvOMKOIZykxnnbo= github.com/vbatts/tar-split v0.12.1/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= +github.com/wundergraph/astjson v1.1.0 h1:xORDosrZ87zQFJwNGe/HIHXqzpdHOFmqWgykCLVL040= +github.com/wundergraph/astjson v1.1.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= github.com/wundergraph/go-arena v0.0.0-20251008210416-55cb97e6f68f h1:5snewyMaIpajTu4wj22L/DgrGimICqXtUVjkZInBH3Y= github.com/wundergraph/go-arena v0.0.0-20251008210416-55cb97e6f68f/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= -github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= @@ -289,6 +304,7 @@ go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN8 go.uber.org/ratelimit v0.3.1 h1:K4qVE+byfv/B3tC+4nYWP7v/6SimcO7HzHekoMNBma0= go.uber.org/ratelimit v0.3.1/go.mod h1:6euWsTB6U/Nb3X++xEUXA8ciPJvr19Q/0h1+oDcJhRk= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= +go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= go.withmatt.com/connect-brotli v0.4.0 h1:7ObWkYmEbUXK3EKglD0Lgj0BBnnD3jNdAxeDRct3l8E= go.withmatt.com/connect-brotli v0.4.0/go.mod h1:c2eELz56za+/Mxh1yJrlglZ4VM9krpOCPqS2Vxf8NVk= golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= @@ -326,6 +342,8 @@ golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -367,6 +385,7 @@ gonum.org/v1/plot v0.10.1 h1:dnifSs43YJuNMDzB7v8wV64O4ABBHReuAVAoBxqBqS4= gonum.org/v1/plot v0.10.1/go.mod h1:VZW5OlhkL1mysU9vaqNHnsy86inf6Ot+jB3r+BczCEo= google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1:GVIKPyP/kLIyVOgOnTwFOrvQaQUzOzGMCxgFUOEmm24= google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw= +google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe00waw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= diff --git a/v2/go.mod b/v2/go.mod index ad5d096fc1..534dfc0540 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -7,38 +7,36 @@ require ( github.com/bufbuild/protocompile v0.14.1 github.com/buger/jsonparser v1.1.1 github.com/cespare/xxhash/v2 v2.3.0 - github.com/coder/websocket v1.8.12 + github.com/coder/websocket v1.8.14 github.com/davecgh/go-spew v1.1.1 - github.com/gobwas/ws v1.4.0 github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.1 github.com/hashicorp/go-plugin v1.6.3 github.com/jensneuse/abstractlogger v0.0.4 - github.com/jensneuse/byte-template v0.0.0-20200214152254-4f3cf06e5c68 + github.com/jensneuse/byte-template v0.0.0-20231025215717-69252eb3ed56 github.com/jensneuse/diffview v1.0.0 github.com/kingledion/go-tools v0.6.0 github.com/kylelemons/godebug v1.1.0 github.com/phf/go-queue v0.0.0-20170504031614-9abe38d0371d github.com/pkg/errors v0.9.1 github.com/r3labs/sse/v2 v2.8.1 + github.com/rs/xid v1.6.0 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/sebdah/goldie/v2 v2.7.1 github.com/stretchr/testify v1.11.1 - github.com/tidwall/gjson v1.17.0 - github.com/tidwall/sjson v1.0.4 + github.com/tidwall/gjson v1.18.0 + github.com/tidwall/sjson v1.2.5 github.com/vektah/gqlparser/v2 v2.5.30 github.com/wundergraph/astjson v1.1.0 github.com/wundergraph/go-arena v1.1.0 - go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 - go.uber.org/zap v1.26.0 golang.org/x/sync v0.17.0 golang.org/x/sys v0.37.0 golang.org/x/text v0.30.0 gonum.org/v1/gonum v0.14.0 - google.golang.org/grpc v1.68.1 + google.golang.org/grpc v1.71.0 google.golang.org/protobuf v1.36.9 gopkg.in/yaml.v2 v2.4.0 ) @@ -50,12 +48,11 @@ require ( github.com/dnephin/pflag v1.0.7 // indirect github.com/fatih/color v1.18.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-logr/logr v1.4.3 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect - github.com/gobwas/httphead v0.1.0 // indirect - github.com/gobwas/pool v0.2.1 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/hashicorp/go-hclog v0.14.1 // indirect + github.com/hashicorp/go-hclog v1.6.3 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/hashicorp/yamux v0.1.1 // indirect github.com/kr/pretty v0.3.1 // indirect @@ -72,7 +69,10 @@ require ( github.com/tidwall/pretty v1.2.1 // indirect github.com/urfave/cli/v2 v2.27.7 // indirect github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 // indirect + go.opentelemetry.io/otel v1.36.0 // indirect + go.opentelemetry.io/otel/sdk/metric v1.36.0 // indirect go.uber.org/multierr v1.11.0 // indirect + go.uber.org/zap v1.27.0 // indirect golang.org/x/mod v0.29.0 // indirect golang.org/x/net v0.46.0 // indirect golang.org/x/term v0.36.0 // indirect diff --git a/v2/go.sum b/v2/go.sum index 13adfeb881..f8d6249fc8 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -15,8 +15,8 @@ github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMU github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= -github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -27,19 +27,17 @@ github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54 h1:SG7nF6SRlWhcT7c github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= github.com/dnephin/pflag v1.0.7 h1:oxONGlWxhmUct0YzKTgrpQv9AUA1wtPBn7zuSjJqptk= github.com/dnephin/pflag v1.0.7/go.mod h1:uxE91IoWURlOiTUIA8Mq5ZZkAv3dPUfZNaT80Zm7OQE= -github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= -github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= -github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= -github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= -github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= -github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= -github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= @@ -53,8 +51,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= -github.com/hashicorp/go-hclog v0.14.1 h1:nQcJDQwIAGnmoUWp8ubocEX40cCml/17YkF6csQLReU= -github.com/hashicorp/go-hclog v0.14.1/go.mod h1:whpDNt7SSdeAju8AWKIWsul05p54N/39EeqMAyrmvFQ= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= github.com/hashicorp/go-plugin v1.6.3 h1:xgHB+ZUSYeuJi96WtxEjzi23uh7YQpznjGh0U0UUrwg= github.com/hashicorp/go-plugin v1.6.3/go.mod h1:MRobyh+Wc/nYy1V4KAXUiYfzxoYhs7V1mlH1Z7iY2h0= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= @@ -63,8 +61,8 @@ github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/jensneuse/abstractlogger v0.0.4 h1:sa4EH8fhWk3zlTDbSncaWKfwxYM8tYSlQ054ETLyyQY= github.com/jensneuse/abstractlogger v0.0.4/go.mod h1:6WuamOHuykJk8zED/R0LNiLhWR6C7FIAo43ocUEB3mo= -github.com/jensneuse/byte-template v0.0.0-20200214152254-4f3cf06e5c68 h1:E80wOd3IFQcoBxLkAUpUQ3BoGrZ4DxhQdP21+HH1s6A= -github.com/jensneuse/byte-template v0.0.0-20200214152254-4f3cf06e5c68/go.mod h1:0D5r/VSW6D/o65rKLL9xk7sZxL2+oku2HvFPYeIMFr4= +github.com/jensneuse/byte-template v0.0.0-20231025215717-69252eb3ed56 h1:wo26fh6a6Za0cOMZIopD2sfH/kq83SJ89ixUWl7pCWc= +github.com/jensneuse/byte-template v0.0.0-20231025215717-69252eb3ed56/go.mod h1:0D5r/VSW6D/o65rKLL9xk7sZxL2+oku2HvFPYeIMFr4= github.com/jensneuse/diffview v1.0.0 h1:4b6FQJ7y3295JUHU3tRko6euyEboL825ZsXeZZM47Z4= github.com/jensneuse/diffview v1.0.0/go.mod h1:i6IacuD8LnEaPuiyzMHA+Wfz5mAuycMOf3R/orUY9y4= github.com/jhump/protoreflect v1.15.1 h1:HUMERORf3I3ZdX05WaQ6MIpd/NJ434hTp5YiKgfCL6c= @@ -83,11 +81,12 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/oklog/run v1.0.0 h1:Ru7dDtJNOyC66gQ5dQmaCa0qIsAUFY3sFpK1Xk8igrw= @@ -106,6 +105,8 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 h1:lZUw3E0/J3roVtGQ+SCrUrg3ON6NgVqpn3+iol9aGu4= @@ -126,17 +127,19 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= -github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/sjson v1.0.4 h1:UcdIRXff12Lpnu3OLtZvnc03g4vH2suXDXhBwBqmzYg= -github.com/tidwall/sjson v1.0.4/go.mod h1:bURseu1nuBkFpIES5cz6zBtjmYeOQmEESshn7VpF15Y= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= @@ -148,9 +151,19 @@ github.com/wundergraph/go-arena v1.1.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nX github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.36.0 h1:UumtzIklRBY6cI/lllNZlALOF5nNIzJVb16APdvgTXg= +go.opentelemetry.io/otel v1.36.0/go.mod h1:/TcFMXYjyRNh8khOAO9ybYkqaDBb/70aVwkNML4pP8E= +go.opentelemetry.io/otel/metric v1.36.0 h1:MoWPKVhQvJ+eeXWHFBOPoBOi20jh6Iq2CcCREuTYufE= +go.opentelemetry.io/otel/metric v1.36.0/go.mod h1:zC7Ks+yeyJt4xig9DEw9kuUFe5C3zLbVjV2PzT6qzbs= +go.opentelemetry.io/otel/sdk v1.36.0 h1:b6SYIuLRs88ztox4EyrvRti80uXIFy+Sqzoh9kFULbs= +go.opentelemetry.io/otel/sdk v1.36.0/go.mod h1:+lC+mTgD+MUWfjJubi2vvXWcVxyr9rmlshZni72pXeY= +go.opentelemetry.io/otel/sdk/metric v1.36.0 h1:r0ntwwGosWGaa0CrSt8cuNuTcccMXERFwHX4dThiPis= +go.opentelemetry.io/otel/sdk/metric v1.36.0/go.mod h1:qTNOhFDfKRwX0yXOqJYegL5WRaW376QbB7P4Pb0qva4= +go.opentelemetry.io/otel/trace v1.36.0 h1:ahxWNuqZjpdiFAyrIoQ4GIiAIhxAunQR6MUoKrsNd4w= +go.opentelemetry.io/otel/trace v1.36.0/go.mod h1:gQ+OnDZzrybY4k4seLzPAWNwVBBVlF2szhehOBB/tGA= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= -go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= -go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= @@ -158,8 +171,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= -go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= -go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -180,13 +193,16 @@ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= @@ -214,8 +230,8 @@ gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50= -google.golang.org/grpc v1.68.1 h1:oI5oTa11+ng8r8XMMN7jAOmWfPZWbYpCFaMUTACxkM0= -google.golang.org/grpc v1.68.1/go.mod h1:+q1XYFJjShcqn0QZHvCyeR4CXPA+llXIeUIfIe00waw= +google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= +google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/cenkalti/backoff.v1 v1.1.0 h1:Arh75ttbsvlpVA7WtVpH4u9h6Zl46xuptxqLxPiSo4Y= diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index f4268d1f6a..e2de738182 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -425,7 +425,7 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration { p.stopWithError(errors.WithStack(fmt.Errorf("ConfigureSubscription: failed to marshal header: %w", err))) return plan.SubscriptionConfiguration{} } - if err == nil && len(header) != 0 && !bytes.Equal(header, literal.NULL) { + if len(header) != 0 && !bytes.Equal(header, literal.NULL) { input = httpclient.SetInputHeader(input, header) } @@ -1960,8 +1960,6 @@ func (s *Source) Load(ctx context.Context, headers http.Header, input []byte) (d type GraphQLSubscriptionClient interface { // Subscribe to the origin source. The implementation must not block the calling goroutine. Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error - SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error - Unsubscribe(id uint64) } type GraphQLSubscriptionOptions struct { @@ -1996,25 +1994,6 @@ type SubscriptionSource struct { subscriptionOnStartFns []SubscriptionOnStartFn } -func (s *SubscriptionSource) AsyncStart(ctx *resolve.Context, id uint64, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { - var options GraphQLSubscriptionOptions - err := json.Unmarshal(input, &options) - if err != nil { - return err - } - options.Header = headers - if options.Body.Query == "" { - return resolve.ErrUnableToResolve - } - return s.client.SubscribeAsync(ctx, id, options, updater) -} - -// AsyncStop stops the subscription with the given id. AsyncStop is only effective when netPoll is enabled -// because without netPoll we manage the lifecycle of the connection in the subscription client. -func (s *SubscriptionSource) AsyncStop(id uint64) { - s.client.Unsubscribe(id) -} - // Start the subscription. The updater is called on new events. Start needs to be called in a separate goroutine. func (s *SubscriptionSource) Start(ctx *resolve.Context, headers http.Header, input []byte, updater resolve.SubscriptionUpdater) error { var options GraphQLSubscriptionOptions @@ -2029,10 +2008,6 @@ func (s *SubscriptionSource) Start(ctx *resolve.Context, headers http.Header, in return s.client.Subscribe(ctx, options, updater) } -var ( - dataSouceName = []byte("graphql") -) - // SubscriptionOnStart is called when a subscription is started. // Hooks are invoked sequentially, short-circuiting on the first error. func (s *SubscriptionSource) SubscriptionOnStart(ctx resolve.StartupHookContext, input []byte) error { diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go index 05b07df9e1..f1ab22708c 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource_test.go @@ -4006,7 +4006,7 @@ func TestGraphQLDataSource(t *testing.T) { Trigger: resolve.GraphQLSubscriptionTrigger{ Input: []byte(`{"url":"wss://swapi.com/graphql","body":{"query":"subscription{remainingJedis}"}}`), Source: &SubscriptionSource{ - client: NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), + client: NewGraphQLSubscriptionClient(ctx), }, PostProcessing: DefaultPostProcessingConfiguration, SourceName: "ds-id", @@ -4049,7 +4049,7 @@ func TestGraphQLDataSource(t *testing.T) { }, ), Source: &SubscriptionSource{ - client: NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, ctx), + client: NewGraphQLSubscriptionClient(ctx), }, PostProcessing: DefaultPostProcessingConfiguration, SourceName: "ds-id", @@ -8381,97 +8381,12 @@ func (f *FailingSubscriptionClient) Subscribe(ctx *resolve.Context, options Grap return errSubscriptionClientFail } -type testSubscriptionUpdaterChan struct { - updates chan string - complete chan struct{} - closed chan resolve.SubscriptionCloseKind -} - -func newTestSubscriptionUpdaterChan() *testSubscriptionUpdaterChan { - return &testSubscriptionUpdaterChan{ - updates: make(chan string), - complete: make(chan struct{}), - closed: make(chan resolve.SubscriptionCloseKind), - } -} - -func (t *testSubscriptionUpdaterChan) Heartbeat() { - t.updates <- "{}" -} - -func (t *testSubscriptionUpdaterChan) Update(data []byte) { - t.updates <- string(data) -} - -// empty method to satisfy the interface, not used in this tests -func (t *testSubscriptionUpdaterChan) UpdateSubscription(id resolve.SubscriptionIdentifier, data []byte) { -} - -// empty method to satisfy the interface, not used in this tests -func (t *testSubscriptionUpdaterChan) CloseSubscription(kind resolve.SubscriptionCloseKind, id resolve.SubscriptionIdentifier) { -} - -// empty method to satisfy the interface, not used in this tests -func (t *testSubscriptionUpdaterChan) Subscriptions() map[context.Context]resolve.SubscriptionIdentifier { - return make(map[context.Context]resolve.SubscriptionIdentifier) -} - -func (t *testSubscriptionUpdaterChan) Complete() { - close(t.complete) -} - -func (t *testSubscriptionUpdaterChan) Close(kind resolve.SubscriptionCloseKind) { - t.closed <- kind -} - -func (t *testSubscriptionUpdaterChan) AwaitUpdateWithT(tt *testing.T, timeout time.Duration, f func(t *testing.T, update string), msgAndArgs ...any) { - tt.Helper() - - select { - case args := <-t.updates: - f(tt, args) - case <-time.After(timeout): - require.Fail(tt, "unable to receive update before timeout", msgAndArgs...) - } -} - -func (t *testSubscriptionUpdaterChan) AwaitClose(tt *testing.T, timeout time.Duration, msgAndArgs ...any) { - tt.Helper() - - select { - case <-t.closed: - case <-time.After(timeout): - require.Fail(tt, "updater not closed before timeout", msgAndArgs...) - } -} - -func (t *testSubscriptionUpdaterChan) AwaitCloseKind(tt *testing.T, timeout time.Duration, expectedCloseKind resolve.SubscriptionCloseKind, msgAndArgs ...any) { - tt.Helper() - - select { - case closeKind := <-t.closed: - require.Equal(tt, expectedCloseKind, closeKind, msgAndArgs...) - case <-time.After(timeout): - require.Fail(tt, "updater not closed before timeout", msgAndArgs...) - } -} - -func (t *testSubscriptionUpdaterChan) AwaitComplete(tt *testing.T, timeout time.Duration, msgAndArgs ...any) { - tt.Helper() - - select { - case <-t.complete: - case <-time.After(timeout): - require.Fail(tt, "updater not completed before timeout", msgAndArgs...) - } -} - // !! If you see this in a test you're working on, please replace it with the new testSubscriptionUpdaterChan // It's faster, more ergonomic and more reliable. See SSE handler tests for usage examples. type testSubscriptionUpdater struct { updates []string + errors []string done bool - closed bool mux sync.Mutex } @@ -8496,6 +8411,27 @@ func (t *testSubscriptionUpdater) AwaitUpdates(tt *testing.T, timeout time.Durat } } +func (t *testSubscriptionUpdater) AwaitErrors(tt *testing.T, timeout time.Duration, count int) { + tt.Helper() + + ticker := time.NewTicker(timeout) + defer ticker.Stop() + for { + time.Sleep(10 * time.Millisecond) + select { + case <-ticker.C: + tt.Fatalf("timed out waiting for errors") + default: + t.mux.Lock() + if len(t.errors) == count { + t.mux.Unlock() + return + } + t.mux.Unlock() + } + } +} + func (t *testSubscriptionUpdater) AwaitDone(tt *testing.T, timeout time.Duration) { tt.Helper() @@ -8535,14 +8471,20 @@ func (t *testSubscriptionUpdater) Complete() { t.done = true } -func (t *testSubscriptionUpdater) Close(kind resolve.SubscriptionCloseKind) { +func (t *testSubscriptionUpdater) Error(data []byte) { + t.mux.Lock() + defer t.mux.Unlock() + t.errors = append(t.errors, string(data)) +} + +func (t *testSubscriptionUpdater) Done() { t.mux.Lock() defer t.mux.Unlock() - t.closed = true + t.done = true } // empty method to satisfy the interface, not used in this tests -func (t *testSubscriptionUpdater) CloseSubscription(kind resolve.SubscriptionCloseKind, id resolve.SubscriptionIdentifier) { +func (t *testSubscriptionUpdater) CloseSubscription(id resolve.SubscriptionIdentifier) { } // empty method to satisfy the interface, not used in this tests @@ -8591,8 +8533,7 @@ func TestSubscriptionSource_Start(t *testing.T) { } newSubscriptionSource := func(ctx context.Context) SubscriptionSource { - httpClient := http.Client{} - subscriptionSource := SubscriptionSource{client: NewGraphQLSubscriptionClient(&httpClient, http.DefaultClient, ctx)} + subscriptionSource := SubscriptionSource{client: NewGraphQLSubscriptionClient(ctx)} return subscriptionSource } @@ -8631,9 +8572,9 @@ func TestSubscriptionSource_Start(t *testing.T) { chatSubscriptionOptions := chatServerSubscriptionOptions(t, `{"variables": {}, "extensions": {}, "operationName": "LiveMessages", "query": "subscription LiveMessages { messageAdded(roomNam: \"#test\") { text createdBy } }"}`) err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) - updater.AwaitUpdates(t, time.Second, 1) - assert.Len(t, updater.updates, 1) - assert.Equal(t, `{"errors":[{"message":"Unknown argument \"roomNam\" on field \"Subscription.messageAdded\". Did you mean \"roomName\"?","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"Field \"messageAdded\" argument \"roomName\" of type \"String!\" is required, but it was not provided.","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}]}`, updater.updates[0]) + updater.AwaitErrors(t, time.Second, 1) + assert.Len(t, updater.errors, 1) + assert.Equal(t, `{"errors":[{"message":"Unknown argument \"roomNam\" on field \"Subscription.messageAdded\". Did you mean \"roomName\"?","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"Field \"messageAdded\" argument \"roomName\" of type \"String!\" is required, but it was not provided.","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}]}`, updater.errors[0]) updater.AwaitDone(t, time.Second) }) @@ -8719,9 +8660,8 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { } newSubscriptionSource := func(ctx context.Context) SubscriptionSource { - httpClient := http.Client{} subscriptionSource := SubscriptionSource{ - client: NewGraphQLSubscriptionClient(&httpClient, http.DefaultClient, ctx), + client: NewGraphQLSubscriptionClient(ctx), } return subscriptionSource } @@ -8737,9 +8677,9 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) { err := source.Start(ctx, nil, chatSubscriptionOptions, updater) require.NoError(t, err) - updater.AwaitUpdates(t, time.Second, 1) - assert.Len(t, updater.updates, 1) - assert.Equal(t, `{"errors":[{"message":"Unknown argument \"roomNam\" on field \"Subscription.messageAdded\". Did you mean \"roomName\"?","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"Field \"messageAdded\" argument \"roomName\" of type \"String!\" is required, but it was not provided.","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}]}`, updater.updates[0]) + updater.AwaitErrors(t, time.Second, 1) + assert.Len(t, updater.errors, 1) + assert.Equal(t, `{"errors":[{"message":"Unknown argument \"roomNam\" on field \"Subscription.messageAdded\". Did you mean \"roomName\"?","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"Field \"messageAdded\" argument \"roomName\" of type \"String!\" is required, but it was not provided.","locations":[{"line":1,"column":29}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}]}`, updater.errors[0]) updater.AwaitDone(t, time.Second) assert.Equal(t, true, updater.done) }) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go deleted file mode 100644 index 512d7ca4b0..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler.go +++ /dev/null @@ -1,275 +0,0 @@ -package graphql_datasource - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "math" - "net/http" - - "github.com/buger/jsonparser" - log "github.com/jensneuse/abstractlogger" - "github.com/r3labs/sse/v2" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" -) - -var ( - headerData = []byte("data:") - headerEvent = []byte("event:") - - eventTypeComplete = []byte("complete") - eventTypeNext = []byte("next") -) - -type gqlSSEConnectionHandler struct { - conn *http.Client - requestContext, engineContext context.Context - log log.Logger - options GraphQLSubscriptionOptions - updater resolve.SubscriptionUpdater -} - -func newSSEConnectionHandler(requestContext, engineContext context.Context, conn *http.Client, updater resolve.SubscriptionUpdater, options GraphQLSubscriptionOptions, l log.Logger) *gqlSSEConnectionHandler { - return &gqlSSEConnectionHandler{ - conn: conn, - requestContext: requestContext, - engineContext: engineContext, - log: l, - updater: updater, - options: options, - } -} - -func (h *gqlSSEConnectionHandler) StartBlocking() { - defer h.updater.Close(resolve.SubscriptionCloseKindNormal) - - resp, err := h.performSubscriptionRequest() - if err != nil { - h.log.Error("failed to perform subscription request", log.Error(err)) - - if h.requestContext.Err() != nil { - // request context was canceled do not send an error as channel will be closed - return - } - - h.updater.Update([]byte(internalError)) - - return - } - - defer func() { - _ = resp.Body.Close() - }() - - reader := sse.NewEventStreamReader(resp.Body, math.MaxInt) - - for { - select { - case <-h.requestContext.Done(): - return - case <-h.engineContext.Done(): - return - default: - } - - msg, err := reader.ReadEvent() - if err != nil { - if err == io.EOF { - return - } - - h.log.Error("failed to read event", log.Error(err)) - h.updater.Update([]byte(internalError)) - return - } - - if len(msg) == 0 { - continue - } - - // normalize the crlf to lf to make it easier to split the lines. - // split the line by "\n" or "\r", per the spec. - lines := bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) - for _, line := range lines { - switch { - case bytes.HasPrefix(line, headerData): - data := trim(line[len(headerData):]) - - if len(data) == 0 { - continue - } - - if h.requestContext.Err() != nil { - // request context was canceled do not send an error as channel will be closed - return - } - - h.updater.Update(data) - case bytes.HasPrefix(line, headerEvent): - event := trim(line[len(headerEvent):]) - - switch { - case bytes.Equal(event, eventTypeComplete): - h.updater.Complete() - return - case bytes.Equal(event, eventTypeNext): - continue - } - case bytes.HasPrefix(msg, []byte(":")): - // according to the spec, we ignore messages starting with a colon - continue - default: - // ideally we should not get here, or if we do, we should ignore it - // but some providers send a json object with the error messages, without the event header - - // check for errors which came without event header - data := trim(line) - - val, valueType, _, err := jsonparser.Get(data, "errors") - switch err { - case jsonparser.KeyPathNotFoundError: - continue - case jsonparser.MalformedJsonError: - // ignore garbage - continue - case nil: - switch valueType { - case jsonparser.Array: - response := []byte(`{}`) - response, err = jsonparser.Set(response, val, "errors") - if err != nil { - h.log.Error("failed to set errors", log.Error(err)) - h.updater.Update([]byte(internalError)) - return - } - h.updater.Update(response) - return - case jsonparser.Object: - response := []byte(`{"errors":[]}`) - response, err = jsonparser.Set(response, val, "errors", "[0]") - if err != nil { - h.log.Error("failed to set errors", log.Error(err)) - h.updater.Update([]byte(internalError)) - return - } - h.updater.Update(response) - return - default: - // don't crash on unexpected payloads from upstream - h.log.Error(fmt.Sprintf("unexpected value type: %d", valueType)) - h.updater.Update([]byte(internalError)) - return - } - - default: - h.log.Error("failed to parse errors", log.Error(err)) - h.updater.Update([]byte(internalError)) - return - } - } - } - } -} - -func trim(data []byte) []byte { - // remove the leading space - data = bytes.TrimLeft(data, " \t") - - // remove the trailing new line - data = bytes.TrimRight(data, "\n") - - return data -} - -func (h *gqlSSEConnectionHandler) performSubscriptionRequest() (*http.Response, error) { - var req *http.Request - var err error - - // default to GET requests when SSEMethodPost is not enabled in the SubscriptionConfiguration - if h.options.SSEMethodPost { - req, err = h.buildPOSTRequest(h.requestContext) - } else { - req, err = h.buildGETRequest(h.requestContext) - } - - if err != nil { - return nil, err - } - - resp, err := h.conn.Do(req) - if err != nil { - return nil, err - } - - switch resp.StatusCode { - case http.StatusOK: - return resp, nil - default: - return nil, fmt.Errorf("failed to connect to stream unexpected resp status code: %d", resp.StatusCode) - } -} - -func (h *gqlSSEConnectionHandler) buildGETRequest(ctx context.Context) (*http.Request, error) { - req, err := http.NewRequestWithContext(ctx, "GET", h.options.URL, nil) - if err != nil { - return nil, err - } - - if h.options.Header != nil { - req.Header = h.options.Header - } - - query := req.URL.Query() - query.Add("query", h.options.Body.Query) - - if h.options.Body.Variables != nil { - variables, _ := h.options.Body.Variables.MarshalJSON() - - query.Add("variables", string(variables)) - } - - if h.options.Body.OperationName != "" { - query.Add("operationName", h.options.Body.OperationName) - } - - if h.options.Body.Extensions != nil { - extensions, _ := h.options.Body.Extensions.MarshalJSON() - - query.Add("extensions", string(extensions)) - } - - req.URL.RawQuery = query.Encode() - h.setSSEHeaders(req) - - return req, nil -} - -func (h *gqlSSEConnectionHandler) buildPOSTRequest(ctx context.Context) (*http.Request, error) { - body, err := json.Marshal(h.options.Body) - if err != nil { - return nil, err - } - - req, err := http.NewRequestWithContext(ctx, "POST", h.options.URL, bytes.NewBuffer(body)) - if err != nil { - return nil, err - } - - if h.options.Header != nil { - req.Header = h.options.Header - } - - req.Header.Set("Content-Type", "application/json") - h.setSSEHeaders(req) - return req, nil -} - -// setSSEHeaders sets the headers required for SSE for both GET and POST requests -func (h *gqlSSEConnectionHandler) setSSEHeaders(req *http.Request) { - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Connection", "keep-alive") - req.Header.Set("Cache-Control", "no-cache") -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler_test.go deleted file mode 100644 index 4c50f19176..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_sse_handler_test.go +++ /dev/null @@ -1,646 +0,0 @@ -package graphql_datasource - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/flags" -) - -func TestGraphQLSubscriptionClientSubscribe_SSE(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - urlQuery := r.URL.Query() - assert.Equal(t, "subscription {messageAdded(roomName: \"room\"){text}}", urlQuery.Get("query")) - - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"messageAdded":{"text":"first"}}}`) - flusher.Flush() - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"messageAdded":{"text":"second"}}}`) - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := &testSubscriptionUpdater{} - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdates(t, time.Second, 2) - assert.Equal(t, 2, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestGraphQLSubscriptionClientSubscribe_SSE_RequestAbort(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - ctx, clientCancel := context.WithCancel(context.Background()) - // cancel after start the request - clientCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, t.Context(), - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: "http://dummy", - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitClose(t, time.Second) -} - -func TestGraphQLSubscriptionClientSubscribe_SSE_POST(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - postReqBody := GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - } - expectedReqBody, err := json.Marshal(postReqBody) - assert.NoError(t, err) - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodPost, r.Method) - - actualReqBody, err := io.ReadAll(r.Body) - assert.NoError(t, err) - assert.Equal(t, expectedReqBody, actualReqBody) - - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"messageAdded":{"text":"first"}}}`) - flusher.Flush() - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"messageAdded":{"text":"second"}}}`) - flusher.Flush() - - _, _ = fmt.Fprintf(w, "event: complete\n\n") - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err = client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: postReqBody, - UseSSE: true, - SSEMethodPost: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, update) - }) - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, update) - }) - - updater.AwaitComplete(t, time.Second) - - clientCancel() - - select { - case <-serverDone: - case <-time.After(time.Second): - require.Fail(t, "server did not close") - } - - serverCancel() -} - -func TestGraphQLSubscriptionClientSubscribe_SSE_WithEvents(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "event: next\ndata: %s\n\n", `{"data":{"messageAdded":{"text":"first"}}}`) - flusher.Flush() - - _, _ = fmt.Fprintf(w, "event: next\ndata: %s\n\n", `{"data":{"messageAdded":{"text":"second"}}}`) - flusher.Flush() - - _, _ = fmt.Fprintf(w, "event: complete\n\n") - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, update) - }) - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, update) - }) - - updater.AwaitComplete(t, time.Second) - - clientCancel() - - select { - case <-serverDone: - case <-time.After(time.Second): - require.Fail(t, "server did not close") - } - - serverCancel() -} - -func TestGraphQLSubscriptionClientSubscribe_SSE_Error(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"errors":[{"message":"Unexpected error.","locations":[{"line":2,"column":3}],"path":["countdown"]}]}`) - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"errors":[{"message":"Unexpected error.","locations":[{"line":2,"column":3}],"path":["countdown"]}]}`, update) - }) - - updater.AwaitClose(t, time.Second) - - clientCancel() - - select { - case <-serverDone: - case <-time.After(time.Second): - require.Fail(t, "server did not close") - } - - serverCancel() -} - -func TestGraphQLSubscriptionClientSubscribe_SSE_Error_Without_Header(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - testCases := []struct { - name string - errorMessage string - expectedErr string - }{ - { - name: "object_error_value", - errorMessage: `{"errors":{"message":"Unexpected error.","locations":[{"line":2,"column":3}],"path":["countdown"]},"data":null}`, - expectedErr: `{"errors":[{"message":"Unexpected error.","locations":[{"line":2,"column":3}],"path":["countdown"]}]}`, - }, - { - name: "list_error_value", - errorMessage: `{"errors":[{"message":"Unexpected error.","locations":[{"line":2,"column":3}],"path":["countdown"]}],"data":null}`, - expectedErr: `{"errors":[{"message":"Unexpected error.","locations":[{"line":2,"column":3}],"path":["countdown"]}]}`, - }, - { - name: "string_error_value", - errorMessage: `{"errors": "some string error"}`, - expectedErr: `{"errors":[{"message":"internal error"}]}`, - }, - { - name: "number_error_value", - errorMessage: `{"errors": 123}`, - expectedErr: `{"errors":[{"message":"internal error"}]}`, - }, - { - name: "boolean_true_error_value", - errorMessage: `{"errors": true}`, - expectedErr: `{"errors":[{"message":"internal error"}]}`, - }, - { - name: "null_error_value", - errorMessage: `{"errors": null}`, - expectedErr: `{"errors":[{"message":"internal error"}]}`, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - // Send the malformed error message WITHOUT the "data:" prefix - // This triggers the error parsing logic in the default case - _, _ = fmt.Fprintf(w, "%s\n\n", tc.errorMessage) - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, tc.expectedErr, update) - }) - - updater.AwaitClose(t, time.Second) - - clientCancel() - - select { - case <-serverDone: - case <-time.After(time.Second): - require.Fail(t, "server did not close") - } - - serverCancel() - }) - } -} - -func TestGraphQLSubscriptionClientSubscribe_QueryParams(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - urlQuery := r.URL.Query() - assert.Equal(t, "subscription($a: Int!){countdown(from: $a)}", urlQuery.Get("query")) - assert.Equal(t, "CountDown", urlQuery.Get("operationName")) - assert.Equal(t, `{"a":5}`, urlQuery.Get("variables")) - assert.Equal(t, `{"persistedQuery":{"version":1,"sha256Hash":"d41d8cd98f00b204e9800998ecf8427e"}}`, urlQuery.Get("extensions")) - - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"countdown":5}}`) - flusher.Flush() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription($a: Int!){countdown(from: $a)}`, - OperationName: "CountDown", - Variables: []byte(`{"a":5}`), - Extensions: []byte(`{"persistedQuery":{"version":1,"sha256Hash":"d41d8cd98f00b204e9800998ecf8427e"}}`), - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"data":{"countdown":5}}`, update) - }) - - updater.AwaitClose(t, time.Second) - - clientCancel() - - select { - case <-serverDone: - case <-time.After(time.Second): - require.Fail(t, "server did not close") - } - - serverCancel() -} - -func TestBuildPOSTRequestSSE(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - subscriptionOptions := GraphQLSubscriptionOptions{ - URL: "test", - Body: GraphQLBody{ - Query: `subscription($a: Int!){countdown(from: $a)}`, - OperationName: "CountDown", - Variables: []byte(`{"a":5}`), - Extensions: []byte(`{"persistedQuery":{"version":1,"sha256Hash":"d41d8cd98f00b204e9800998ecf8427e"}}`), - }, - } - - h := gqlSSEConnectionHandler{ - options: subscriptionOptions, - } - - req, err := h.buildPOSTRequest(context.Background()) - assert.NoError(t, err) - - expectedReqBody, err := json.Marshal(subscriptionOptions.Body) - assert.NoError(t, err) - - assert.Equal(t, http.MethodPost, req.Method) - - actualReqBody, err := io.ReadAll(req.Body) - assert.NoError(t, err) - assert.Equal(t, expectedReqBody, actualReqBody) -} - -func TestBuildGETRequestSSE(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - subscriptionOptions := GraphQLSubscriptionOptions{ - URL: "test", - Body: GraphQLBody{ - Query: `subscription($a: Int!){countdown(from: $a)}`, - OperationName: "CountDown", - Variables: []byte(`{"a":5}`), - Extensions: []byte(`{"persistedQuery":{"version":1,"sha256Hash":"d41d8cd98f00b204e9800998ecf8427e"}}`), - }, - } - - h := gqlSSEConnectionHandler{ - options: subscriptionOptions, - } - - req, err := h.buildGETRequest(context.Background()) - assert.NoError(t, err) - - assert.Equal(t, http.MethodGet, req.Method) - - urlQuery := req.URL.Query() - assert.Equal(t, subscriptionOptions.Body.Query, urlQuery.Get("query")) - assert.Equal(t, subscriptionOptions.Body.OperationName, urlQuery.Get("operationName")) - - assert.Equal(t, string(subscriptionOptions.Body.Variables), urlQuery.Get("variables")) - assert.Equal(t, string(subscriptionOptions.Body.Extensions), urlQuery.Get("extensions")) - -} - -func TestGraphQLSubscriptionClientSubscribe_SSE_Upstream_Dies(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - urlQuery := r.URL.Query() - assert.Equal(t, "subscription {messageAdded(roomName: \"room\"){text}}", urlQuery.Get("query")) - - // Make sure that the writer supports flushing. - flusher, ok := w.(http.Flusher) - require.True(t, ok) - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - _, _ = fmt.Fprintf(w, "data: %s\n\n", `{"data":{"messageAdded":{"text":"first"}}}`) - flusher.Flush() - - // Kill the upstream server. We should catch this event as an "unexpected EOF" - // error and return an error message to the subscriber. - h, ok := w.(http.Hijacker) - require.True(t, ok) - rawConn, _, err := h.Hijack() - require.NoError(t, err) - _ = rawConn.Close() - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := newTestSubscriptionUpdaterChan() - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - UseSSE: true, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, update) - }) - - updater.AwaitUpdateWithT(t, time.Second, func(t *testing.T, update string) { - assert.Equal(t, `{"errors":[{"message":"internal error"}]}`, update) - }) - - updater.AwaitClose(t, time.Second) - - clientCancel() - - select { - case <-serverDone: - case <-time.After(time.Second): - require.Fail(t, "server did not close") - } - - serverCancel() -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go index c8a08df03f..b2b92ad28e 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client.go @@ -2,930 +2,347 @@ package graphql_datasource import ( "context" - "crypto/rand" - "crypto/sha1" - "crypto/tls" - "encoding/base64" + "encoding/json" "errors" "fmt" - "io" - "net" "net/http" - "net/http/httptrace" - "strings" - "sync" - "syscall" "time" - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsutil" + "github.com/coder/websocket" "github.com/jensneuse/abstractlogger" - "go.uber.org/atomic" + client "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/wundergraph/graphql-go-tools/v2/pkg/netpoll" ) -const ( - // The time to write a message to the server connection before timing out - writeTimeout = 10 * time.Second - // The time to wait for a connection ack message from the server before timing out - ackWaitTimeout = 30 * time.Second -) - -type netPollState struct { - // connections is a map of fd -> connection to keep track of all active connections - connections map[int]*connection - hasConnections atomic.Bool - // triggers is a map of subscription id -> fd to easily look up the connection for a subscription id - triggers map[uint64]int - - // clientUnsubscribe is a channel to signal to the netPoll run loop that a client needs to be unsubscribed - clientUnsubscribe chan uint64 - // addConn is a channel to signal to the netPoll run loop that a new connection needs to be added - addConn chan *connection - // waitForEventsTicker is the ticker for the netPoll run loop - // it is used to prevent busy waiting and to limit the CPU usage - // instead of polling the netPoll instance all the time, we wait until the next tick to throttle the netPoll loop - waitForEventsTicker *time.Ticker - - // waitForEventsTick is the channel to receive the tick from the waitForEventsTicker - waitForEventsTick <-chan time.Time -} - -// subscriptionClient allows running multiple subscriptions via the same WebSocket either SSE connection -// It takes care of de-duplicating connections to the same origin under certain circumstances -// If Hash(URL,Body,Headers) result in the same result, an existing connection is re-used -type subscriptionClient struct { - streamingClient *http.Client - httpClient *http.Client +// subscriptionClientConfig holds the subscription client configuration. +type subscriptionClientConfig struct { + UpgradeClient *http.Client + StreamingClient *http.Client + Logger abstractlogger.Logger - useHttpClientWithSkipRoundTrip bool + // Timeouts and limits + PingInterval time.Duration + PingTimeout time.Duration + AckTimeout time.Duration + WriteTimeout time.Duration + ReadLimit int64 - engineCtx context.Context - log abstractlogger.Logger - hashPool sync.Pool - onWsConnectionInitCallback *OnWsConnectionInitCallback - - readTimeout time.Duration - pingInterval time.Duration - frameTimeout time.Duration - pingTimeout time.Duration - - netPoll netpoll.Poller - netPollConfig NetPollConfiguration - netPollState *netPollState + // DefaultErrorExtensionCode is the extension code attached to GraphQL + // errors produced by upstream connection failures. Should match the + // resolve package's setting for consistent error formatting. + DefaultErrorExtensionCode string } -func (c *subscriptionClient) SubscribeAsync(ctx *resolve.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { - if options.UseSSE { - return c.subscribeSSE(ctx.Context(), c.engineCtx, options, updater) - } +func defaultsubscriptionClientConfig() *subscriptionClientConfig { + return &subscriptionClientConfig{ + UpgradeClient: http.DefaultClient, + StreamingClient: http.DefaultClient, + Logger: abstractlogger.NoopLogger, - if strings.HasPrefix(options.URL, "https") { - options.URL = "wss" + options.URL[5:] - } else if strings.HasPrefix(options.URL, "http") { - options.URL = "ws" + options.URL[4:] + PingInterval: 30 * time.Second, + PingTimeout: 10 * time.Second, } - - return c.asyncSubscribeWS(ctx.Context(), c.engineCtx, id, options, updater) -} - -func (c *subscriptionClient) Unsubscribe(id uint64) { - // if we don't have netPoll, we don't have a channel consumer of the clientUnsubscribe channel - // we have to return to prevent a deadlock - if c.netPoll == nil { - return - } - c.netPollState.clientUnsubscribe <- id } -type InvalidWsSubprotocolError struct { - InvalidProtocol string -} - -func (e InvalidWsSubprotocolError) Error() string { - return fmt.Sprintf("provided websocket subprotocol '%s' is not supported. The supported subprotocols are graphql-ws and graphql-transport-ws. Please configure your subsciptions with the mentioned subprotocols", e.InvalidProtocol) -} +// SubscriptionClientOption configures the subscription client. +type SubscriptionClientOption func(*subscriptionClientConfig) -func NewInvalidWsSubprotocolError(invalidProtocol string) InvalidWsSubprotocolError { - return InvalidWsSubprotocolError{ - InvalidProtocol: invalidProtocol, +// WithUpgradeClient sets the HTTP client used for WebSocket upgrade requests. +func WithUpgradeClient(c *http.Client) SubscriptionClientOption { + return func(cfg *subscriptionClientConfig) { + if c != nil { + cfg.UpgradeClient = c + } } } -type Options func(options *opts) - -func WithLogger(log abstractlogger.Logger) Options { - return func(options *opts) { - options.log = log +// WithStreamingClient sets the HTTP client used for SSE requests. +// This client should have appropriate timeouts for long-lived connections. +func WithStreamingClient(c *http.Client) SubscriptionClientOption { + return func(cfg *subscriptionClientConfig) { + if c != nil { + cfg.StreamingClient = c + } } } -func WithReadTimeout(timeout time.Duration) Options { - return func(options *opts) { - options.readTimeout = timeout +// WithLogger sets the logger for the client and its transports. +// If not set, logging is disabled (silent operation). +func WithLogger(log abstractlogger.Logger) SubscriptionClientOption { + return func(cfg *subscriptionClientConfig) { + if log != nil { + cfg.Logger = log + } } } -func WithPingInterval(interval time.Duration) Options { - return func(options *opts) { - options.pingInterval = interval +// WithPingInterval sets the interval between ping messages for connection health checks. +// Only applies to graphql-transport-ws protocol (legacy graphql-ws uses server-initiated keepalive). +// Default: 30s. Set to 0 to disable client-initiated pings. +func WithPingInterval(d time.Duration) SubscriptionClientOption { + return func(cfg *subscriptionClientConfig) { + cfg.PingInterval = d } } -func WithFrameTimeout(timeout time.Duration) Options { - return func(options *opts) { - options.frameTimeout = timeout +// WithPingTimeout sets the maximum time to wait for a pong response. +// If no pong is received within this duration, the connection is considered dead. +// Default: 10s. Set to 0 to disable the pong-timeout check. +func WithPingTimeout(d time.Duration) SubscriptionClientOption { + return func(cfg *subscriptionClientConfig) { + cfg.PingTimeout = d } } -func WithPingTimeout(timeout time.Duration) Options { - return func(options *opts) { - options.pingTimeout = timeout +// WithAckTimeout sets the maximum time to wait for connection_ack after connection_init. +func WithAckTimeout(d time.Duration) SubscriptionClientOption { + return func(cfg *subscriptionClientConfig) { + cfg.AckTimeout = d } } -type NetPollConfiguration struct { - // Enable can be set to true to enable netPoll - Enable bool - // BufferSize defines the size of the buffer for the netPoll loop - BufferSize int - // WaitForNumEvents defines how many events are waited for in the netPoll loop before TickInterval cancels the wait - WaitForNumEvents int - // MaxEventWorkers defines the parallelism of how many connections can be handled at the same time - // The higher the number, the more CPU is used. - MaxEventWorkers int - // TickInterval defines the time between each netPoll loop when WaitForNumEvents is not reached - TickInterval time.Duration -} - -func (e *NetPollConfiguration) ApplyDefaults() { - e.Enable = true - - if e.BufferSize == 0 { - e.BufferSize = 1024 - } - if e.MaxEventWorkers == 0 { - e.MaxEventWorkers = 6 - } - if e.WaitForNumEvents == 0 { - e.WaitForNumEvents = 1024 - } - if e.TickInterval == 0 { - e.TickInterval = time.Millisecond * 100 +// WithWriteTimeout sets the timeout for WebSocket write operations (subscribe, unsubscribe, ping, pong). +func WithWriteTimeout(d time.Duration) SubscriptionClientOption { + return func(cfg *subscriptionClientConfig) { + cfg.WriteTimeout = d } } -func WithNetPollConfiguration(config NetPollConfiguration) Options { - return func(options *opts) { - options.netPollConfiguration = config +// WithDefaultErrorExtensionCode sets the extension code attached to GraphQL +// errors produced by upstream connection failures. +func WithDefaultErrorExtensionCode(code string) SubscriptionClientOption { + return func(cfg *subscriptionClientConfig) { + cfg.DefaultErrorExtensionCode = code } } -type opts struct { - readTimeout time.Duration - pingInterval time.Duration - pingTimeout time.Duration - frameTimeout time.Duration - log abstractlogger.Logger - onWsConnectionInitCallback *OnWsConnectionInitCallback - netPollConfiguration NetPollConfiguration -} - -// GraphQLSubscriptionClientFactory abstracts the way of creating a new GraphQLSubscriptionClient. -// This can be very handy for testing purposes. -type GraphQLSubscriptionClientFactory interface { - NewSubscriptionClient(httpClient, streamingClient *http.Client, engineCtx context.Context, options ...Options) GraphQLSubscriptionClient -} - -type DefaultSubscriptionClientFactory struct{} - -func (d *DefaultSubscriptionClientFactory) NewSubscriptionClient(httpClient, streamingClient *http.Client, engineCtx context.Context, options ...Options) GraphQLSubscriptionClient { - return NewGraphQLSubscriptionClient(httpClient, streamingClient, engineCtx, options...) -} - -func IsDefaultGraphQLSubscriptionClient(client GraphQLSubscriptionClient) bool { - _, ok := client.(*subscriptionClient) - return ok -} - -func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engineCtx context.Context, options ...Options) GraphQLSubscriptionClient { - - // Defaults - op := &opts{ - readTimeout: 5 * time.Second, - pingInterval: 15 * time.Second, - pingTimeout: 30 * time.Second, - frameTimeout: 100 * time.Millisecond, - log: abstractlogger.NoopLogger, - } - - op.netPollConfiguration.ApplyDefaults() - - for _, option := range options { - option(op) +// WithReadLimit sets the maximum size in bytes for incoming WebSocket messages. +func WithReadLimit(n int64) SubscriptionClientOption { + return func(cfg *subscriptionClientConfig) { + cfg.ReadLimit = n } - - client := &subscriptionClient{ - httpClient: httpClient, - streamingClient: streamingClient, - engineCtx: engineCtx, - log: op.log, - readTimeout: op.readTimeout, - pingInterval: op.pingInterval, - pingTimeout: op.pingTimeout, - frameTimeout: op.frameTimeout, - hashPool: sync.Pool{ - New: func() interface{} { - return xxhash.New() - }, - }, - onWsConnectionInitCallback: op.onWsConnectionInitCallback, - netPollConfig: op.netPollConfiguration, - } - if op.netPollConfiguration.Enable { - client.netPollState = &netPollState{ - connections: make(map[int]*connection), - triggers: make(map[uint64]int), - clientUnsubscribe: make(chan uint64, op.netPollConfiguration.BufferSize), - addConn: make(chan *connection, op.netPollConfiguration.BufferSize), - // this is not needed, but we want to make it explicit that we're starting with nil as the tick channel - // reading from nil channels blocks forever, which allows us to prevent the netPoll loop from starting - // once we add the first connection, we start the ticker and set the tick channel - // after the last connection is removed, we set the tick channel to nil again - // this way we can start and stop the epoll loop dynamically - waitForEventsTick: nil, - } - - // ignore error is ok, it means that netPoll is not supported, which is handled gracefully by the client - poller, _ := netpoll.NewPoller(op.netPollConfiguration.BufferSize, op.netPollConfiguration.TickInterval) - if poller != nil { - client.netPoll = poller - go client.runNetPoll(engineCtx) - } - } - return client } -type connection struct { - id uint64 - fd int - netConn net.Conn - handler ConnectionHandler - shouldClose bool +// subscriptionClientV2 implements GraphQLSubscriptionClient using the new +// channel-based subscription client. +type subscriptionClientV2 struct { + client *client.Client + defaultErrorExtensionCode string } -// Subscribe initiates a new GraphQL Subscription with the origin -// If an existing WS connection with the same ID (Hash) exists, it is being re-used -// If connection protocol is SSE, a new connection is always created -// If no connection exists, the client initiates a new one -func (c *subscriptionClient) Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { - options.readTimeout = c.readTimeout - if options.UseSSE { - return c.subscribeSSE(ctx.Context(), c.engineCtx, options, updater) +// NewGraphQLSubscriptionClient creates a new subscription client. +func NewGraphQLSubscriptionClient(ctx context.Context, opts ...SubscriptionClientOption) GraphQLSubscriptionClient { + cfg := defaultsubscriptionClientConfig() + for _, opt := range opts { + opt(cfg) } - return c.subscribeWS(ctx.Context(), c.engineCtx, options, updater) -} - -func (c *subscriptionClient) subscribeSSE(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { - options.readTimeout = c.readTimeout - if c.streamingClient == nil { - return fmt.Errorf("streaming http client is nil") + return &subscriptionClientV2{ + defaultErrorExtensionCode: cfg.DefaultErrorExtensionCode, + client: client.New(ctx, client.Config{ + UpgradeClient: cfg.UpgradeClient, + StreamingClient: cfg.StreamingClient, + Logger: cfg.Logger, + PingInterval: cfg.PingInterval, + PingTimeout: cfg.PingTimeout, + AckTimeout: cfg.AckTimeout, + WriteTimeout: cfg.WriteTimeout, + ReadLimit: cfg.ReadLimit, + }), } - - handler := newSSEConnectionHandler(requestContext, engineContext, c.streamingClient, updater, options, c.log) - - go handler.StartBlocking() - - return nil } -func (c *subscriptionClient) subscribeWS(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { - options.readTimeout = c.readTimeout - options.pingInterval = c.pingInterval - options.pingTimeout = c.pingTimeout - - if c.httpClient == nil { - return fmt.Errorf("http client is nil") - } - - conn, err := c.newWSConnectionHandler(requestContext, engineContext, options, updater) +// Subscribe implements GraphQLSubscriptionClient. +func (c *subscriptionClientV2) Subscribe(ctx *resolve.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { + opts, req, err := convertToClientOptions(options) if err != nil { return err } - go func() { - err := conn.handler.StartBlocking() - if err != nil { - if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { - return - } - c.log.Error("subscriptionClient.subscribeWS", abstractlogger.Error(err)) - } - }() + handler := buildMessageHandler(updater, c.defaultErrorExtensionCode) - return nil -} - -func (c *subscriptionClient) asyncSubscribeWS(requestContext, engineContext context.Context, id uint64, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) error { - options.readTimeout = c.readTimeout - options.pingInterval = c.pingInterval - options.pingTimeout = c.pingTimeout - - if c.httpClient == nil { - return fmt.Errorf("http client is nil") - } - - conn, err := c.newWSConnectionHandler(requestContext, engineContext, options, updater) - if err != nil { - return err - } - - if c.netPoll == nil { - go func() { - err := conn.handler.StartBlocking() - if err != nil && !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { - c.log.Error("subscriptionClient.asyncSubscribeWS", abstractlogger.Error(err)) - } - }() - return nil - } - - // if we have netPoll, we need to add the connection to the netPoll - - // init the subscription - err = conn.handler.Subscribe() + cancel, err := c.client.Subscribe(ctx.Context(), req, opts, handler) if err != nil { + if isUpstreamError(err) { + updater.Error(formatUpstreamServiceError(err, c.defaultErrorExtensionCode)) + updater.Done() + return nil + } return err } - var fd int + context.AfterFunc(ctx.Context(), func() { + cancel() + updater.Done() + }) - // we have to check if the connection is a tls connection to get the underlying net.Conn - if tlsConn, ok := conn.netConn.(*tls.Conn); ok { - netConn := tlsConn.NetConn() - fd = netpoll.SocketFD(netConn) - } else { - fd = netpoll.SocketFD(conn.netConn) - } - - if fd == 0 { - c.log.Error("failed to get file descriptor from connection. This indicates a problem with the netPoll implementation") - return fmt.Errorf("failed to get file descriptor from connection") - } - - conn.id, conn.fd = id, fd - // submit the connection to the netPoll run loop - c.netPollState.addConn <- conn return nil } -type UpgradeRequestError struct { - URL string - StatusCode int -} - -func (u *UpgradeRequestError) Error() string { - return fmt.Sprintf("failed to upgrade connection to %s, status code: %d", u.URL, u.StatusCode) -} - -func (c *subscriptionClient) newWSConnectionHandler(requestContext, engineContext context.Context, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater) (*connection, error) { - // Any failure will be stored here, needed for deferred body closer. - var err error - - conn, subProtocol, err := c.dial(requestContext, options) - if err != nil { - return nil, err - } - - if conn == nil { - return nil, fmt.Errorf("failed to dial connection") - } - - // conn is not nil. Any errored return below could lead to a leaking connection. - // To avoid this, close connection if failure happened. - defer func() { - if err != nil { - conn.Close() - } - }() - - initMsg, err := c.getConnectionInitMessage(requestContext, options.URL, options.Header) - if err != nil { - return nil, err - } - - if len(options.InitialPayload) > 0 { - initMsg, err = jsonparser.Set(initMsg, options.InitialPayload, "payload") - if err != nil { - return nil, err - } - } - - if options.Body.Extensions != nil { - initMsg, err = jsonparser.Set(initMsg, options.Body.Extensions, "payload", "extensions") - if err != nil { - return nil, err - } - } - - // init + ack - if err = conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { - return nil, err - } - err = wsutil.WriteClientText(conn, initMsg) - if err != nil { - return nil, err - } - - if err = waitForAck(conn, c.readTimeout, writeTimeout); err != nil { - return nil, err - } - - switch subProtocol { - case ProtocolGraphQLWS: - return newGQLWSConnectionHandler(requestContext, engineContext, conn, options, updater, c.log), nil - case ProtocolGraphQLTWS: - return newGQLTWSConnectionHandler(requestContext, engineContext, conn, options, updater, c.log), nil - default: - return nil, NewInvalidWsSubprotocolError(subProtocol) - } -} - -func (c *subscriptionClient) dial(ctx context.Context, options GraphQLSubscriptionOptions) (conn net.Conn, subProtocol string, err error) { - subProtocols := []string{ProtocolGraphQLTWS, ProtocolGraphQLWS} - if options.WsSubProtocol != "" && options.WsSubProtocol != "auto" { - subProtocols = []string{options.WsSubProtocol} - } - - clientTrace := &httptrace.ClientTrace{ - GotConn: func(info httptrace.GotConnInfo) { - conn = info.Conn - }, - } - clientTraceCtx := httptrace.WithClientTrace(ctx, clientTrace) - u := options.URL - if strings.HasPrefix(options.URL, "wss") { - u = "https" + options.URL[3:] - } else if strings.HasPrefix(options.URL, "ws") { - u = "http" + options.URL[2:] - } - req, err := http.NewRequestWithContext(clientTraceCtx, http.MethodGet, u, nil) - if err != nil { - return nil, "", err - } - req.Proto = "HTTP/1.1" - req.ProtoMajor = 1 - req.ProtoMinor = 1 - if options.Header != nil { - req.Header = options.Header - } - req.Header.Set("Sec-WebSocket-Protocol", strings.Join(subProtocols, ",")) - req.Header.Set("Sec-WebSocket-Version", "13") - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Upgrade", "websocket") - - challengeKey, err := generateChallengeKey() - if err != nil { - return nil, "", err - } - - req.Header.Set("Sec-WebSocket-Key", challengeKey) - - upgradeResponse, err := c.httpClient.Do(req) - if err != nil { - return nil, "", err - } - - // On failed upgrades, we close the body without transferring ownership to the caller. - - if upgradeResponse.StatusCode != http.StatusSwitchingProtocols { - // Drain to EOF to allow connection reuse by net/http. - _, _ = io.Copy(io.Discard, upgradeResponse.Body) - upgradeResponse.Body.Close() - return nil, "", &UpgradeRequestError{ - URL: u, - StatusCode: upgradeResponse.StatusCode, - } - } - - accept := computeAcceptKey(challengeKey) - if upgradeResponse.Header.Get("Sec-WebSocket-Accept") != accept { - _, _ = io.Copy(io.Discard, upgradeResponse.Body) - upgradeResponse.Body.Close() - return nil, "", fmt.Errorf("invalid Sec-WebSocket-Accept") - } - - subProtocol = subProtocols[0] - if options.WsSubProtocol == "" || options.WsSubProtocol == "auto" { - subProtocol = upgradeResponse.Header.Get("Sec-WebSocket-Protocol") - if subProtocol == "" { - subProtocol = ProtocolGraphQLWS - } - } - - return conn, subProtocol, nil -} - -func generateChallengeKey() (string, error) { - p := make([]byte, 16) - if _, err := io.ReadFull(rand.Reader, p); err != nil { - return "", err - } - return base64.StdEncoding.EncodeToString(p), nil -} - -var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") - -func computeAcceptKey(challengeKey string) string { - h := sha1.New() // #nosec G401 -- (CWE-326) https://datatracker.ietf.org/doc/html/rfc6455#page-54 - h.Write([]byte(challengeKey)) - h.Write(keyGUID) - return base64.StdEncoding.EncodeToString(h.Sum(nil)) -} - -func (c *subscriptionClient) getConnectionInitMessage(ctx context.Context, url string, header http.Header) ([]byte, error) { - if c.onWsConnectionInitCallback == nil { - return connectionInitMessage, nil - } - - callback := *c.onWsConnectionInitCallback - - payload, err := callback(ctx, url, header) - if err != nil { - return nil, err - } - - if len(payload) == 0 { - return connectionInitMessage, nil - } - - msg, err := jsonparser.Set(connectionInitMessage, payload, "payload") - if err != nil { - return nil, err - } - - return msg, nil -} - -type ConnectionHandler interface { - // StartBlocking starts the connection handler and blocks until the connection is closed - // Only used as fallback when epoll is not available - StartBlocking() error - // HandleMessage handles the incoming message from the connection - HandleMessage(data []byte) (done bool) - // Ping sends a ping message to the upstream server to keep the connection alive. - // Implementers must keep track of the last ping time to initiate a connection shutdown - // if the upstream is not sending a pong. - Ping() - // ServerClose closes the connection from the server side - ServerClose() - // ClientClose closes the connection from the client side - ClientClose() - // Subscribe subscribes to the connection - Subscribe() error -} - -func waitForAck(conn net.Conn, readTimeout, writeTimeout time.Duration) error { - timer := time.NewTimer(ackWaitTimeout) - for { - select { - case <-timer.C: - return fmt.Errorf("timeout while waiting for connection_ack") - default: - } - if err := conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil { - return fmt.Errorf("failed to set read deadline: %w", err) - } - msg, err := wsutil.ReadServerText(conn) - if err != nil { - return err - } - respType, err := jsonparser.GetString(msg, "type") - if err != nil { - return err - } - - switch respType { - // TODO this method mixes message types from different protocols. We should - // move the specific protocol handling to the concrete implementation - case messageTypeConnectionKeepAlive: - continue - case messageTypePing: - if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { - return fmt.Errorf("failed to set write deadline: %w", err) +// buildMessageHandler creates the handler that bridges client.Message → resolve.SubscriptionUpdater. +func buildMessageHandler(updater resolve.SubscriptionUpdater, errorCode string) client.Handler { + return func(msg *client.Message) { + switch msg.Type { + case client.MessageTypeConnectionError: + updater.Error(formatUpstreamServiceError(msg.Err, errorCode)) + updater.Done() + case client.MessageTypeError: + data, err := json.Marshal(msg.Payload) + if err != nil { + updater.Error(formatSubscriptionError(err)) + updater.Done() + return } - err = wsutil.WriteClientText(conn, []byte(pongMessage)) + updater.Error(data) + updater.Done() + case client.MessageTypeData: + data, err := json.Marshal(msg.Payload) if err != nil { - return fmt.Errorf("failed to send pong message: %w", err) + updater.Error(formatSubscriptionError(err)) + updater.Done() + return } - continue - case messageTypeConnectionAck: - return nil - default: - return fmt.Errorf("expected connection_ack or ka, got %s", respType) + updater.Update(data) + case client.MessageTypeComplete: + updater.Complete() + updater.Done() } } } -type connResult struct { - fd int - shouldClose bool +// isUpstreamError reports whether err is a connection-level upstream error +// that should be surfaced as a GraphQL error to the client. +// ErrFailedUpgrade and ErrInvalidSubprotocol are intentionally excluded so +// they propagate to the router, which formats detailed error messages +// (e.g. including the subgraph name and HTTP status code). +func isUpstreamError(err error) bool { + return errors.Is(err, client.ErrConnectionClosed) || + errors.Is(err, client.ErrConnectionError) || + errors.Is(err, client.ErrInitFailed) || + errors.Is(err, client.ErrDialFailed) || + errors.Is(err, client.ErrAckTimeout) || + errors.Is(err, client.ErrAckNotReceived) || + errors.Is(err, context.Canceled) || + errors.Is(err, context.DeadlineExceeded) } -func (c *subscriptionClient) runNetPoll(ctx context.Context) { - defer c.close() - done := ctx.Done() - // both handleConnCh and connResults are buffered channels with a size of WaitForNumEvents - // this is important because we submit all events before we start processing them - // and we start evaluating the results only after all events have been submitted - // this would not be possible with unbuffered channels - handleConnCh := make(chan *connection, c.netPollConfig.WaitForNumEvents) - connResults := make(chan connResult, c.netPollConfig.WaitForNumEvents) - pingCh := make(chan *connection, c.netPollConfig.WaitForNumEvents) - - // Start workers to handle connection events - // MaxEventWorkers defines the parallelism of how many connections can be handled at the same time - // This is the critical number on how much CPU is used - for i := 0; i < c.netPollConfig.MaxEventWorkers; i++ { - go func() { - for { - select { - case conn := <-pingCh: - conn.handler.Ping() - case conn := <-handleConnCh: - shouldClose := c.handleConnectionEvent(conn) - connResults <- connResult{fd: conn.fd, shouldClose: shouldClose} - case <-done: - return - } - } - }() +// convertToClientOptions converts GraphQLSubscriptionOptions to the new client's types. +func convertToClientOptions(options GraphQLSubscriptionOptions) (client.Options, *client.Request, error) { + opts := client.Options{ + Endpoint: options.URL, + Headers: options.Header, } - pingTicker := time.NewTicker(c.pingInterval) - defer pingTicker.Stop() - - // This is the main netPoll run loop - // It's a single threaded event loop that reacts to several events, such as added connections, clients unsubscribing, etc. - for { - select { - // if the engine context is done, we close the netPoll loop - case <-done: - return - case <-pingTicker.C: - // Send a ping to all connections - // We distribute the ping to all workers to prevent single threaded bottlenecks - // However, this required state synchronization with the last ping time on the handler - // because PING and PONG can be handled on different go routines - for _, conn := range c.netPollState.connections { - pingCh <- conn - } - case conn := <-c.netPollState.addConn: - c.handleAddConn(conn) - case id := <-c.netPollState.clientUnsubscribe: - c.handleClientUnsubscribe(id) - // while len(c.connections) == 0, this channel is nil, so we will never try to wait for netPoll events - // this is important to prevent busy waiting - // once we add the first connection, we start the ticker and set the tick channel - // the ticker ensures that we don't poll the netPoll instance all the time, - // but at most every TickInterval - case <-c.netPollState.waitForEventsTick: - events, err := c.netPoll.Wait(c.netPollConfig.WaitForNumEvents) - if err != nil { - c.log.Error("netPoll.Wait", abstractlogger.Error(err)) - continue - } - - waitForEvents := len(events) - - for i := range events { - fd := netpoll.SocketFD(events[i]) - conn, ok := c.netPollState.connections[fd] - if !ok { - // This can happen if the client was unsubscribed - // and the ticker is still running because we haven't removed the last connection yet - continue - } - // submit the connection to the worker pool - handleConnCh <- conn - } - - // we submit all events to the worker pool to handle all events in parallel - // instead of just waiting until all handlers are done, we can handle newly added connections or clients unsubscribing simultaneously - // we keep doing this until we have results for all events or the engine context is done - // this allows us to keep handling events in parallel while being able to manage connections without locks - // as a result, we can handle a large number of connections with a single threaded event loop - - // once we have results for all events, we can return to the top level loop and wait for the next tick - for waitForEvents > 0 { - select { - case result := <-connResults: - // if the connection indicates that it should be closed, we close and remove it - if result.shouldClose { - c.handleServerUnsubscribe(result.fd) - } - // we decrease the number of events we're waiting for to eventually break the loop - waitForEvents-- - case conn := <-c.netPollState.addConn: - c.handleAddConn(conn) - case id := <-c.netPollState.clientUnsubscribe: - c.handleClientUnsubscribe(id) - case <-done: - return - } - } - } - } -} - -func (c *subscriptionClient) close() { - defer c.log.Debug("subscriptionClient.close", abstractlogger.String("reason", "netPoll closed by context")) - if c.netPollState.waitForEventsTicker != nil { - c.netPollState.waitForEventsTicker.Stop() - } - for _, conn := range c.netPollState.connections { - _ = c.netPoll.Remove(conn.netConn) - conn.handler.ServerClose() - } - if c.netPoll != nil { - err := c.netPoll.Close(false) - if err != nil { - c.log.Error("subscriptionClient.close", abstractlogger.Error(err)) + // Transport selection + if options.UseSSE { + opts.Transport = client.TransportSSE + if options.SSEMethodPost { + opts.SSEMethod = client.SSEMethodPOST + } else { + opts.SSEMethod = client.SSEMethodGET } - } -} - -func (c *subscriptionClient) handleAddConn(conn *connection) { - var netConn net.Conn - - if tlsConn, ok := conn.netConn.(*tls.Conn); ok { - netConn = tlsConn.NetConn() } else { - netConn = conn.netConn + opts.Transport = client.TransportWS + opts.WSSubprotocol = mapWSSubprotocol(options.WsSubProtocol) } - if err := c.netPoll.Add(netConn); err != nil { - c.log.Error("subscriptionClient.handleAddConn", abstractlogger.Error(err)) - conn.handler.ServerClose() - return + // Convert InitialPayload from json.RawMessage to map[string]any + if len(options.InitialPayload) > 0 { + var initPayload map[string]any + if err := json.Unmarshal(options.InitialPayload, &initPayload); err != nil { + return client.Options{}, nil, fmt.Errorf("failed to unmarshal initial payload: %w", err) + } + opts.InitPayload = initPayload } - c.netPollState.connections[conn.fd] = conn - c.netPollState.triggers[conn.id] = conn.fd - // when we previously had 0 connections, we will have 1 connection now - // this means we need to start the ticker so that we get netPoll events - if len(c.netPollState.connections) == 1 { - c.netPollState.waitForEventsTicker = time.NewTicker(c.netPollConfig.TickInterval) - c.netPollState.waitForEventsTick = c.netPollState.waitForEventsTicker.C - c.netPollState.hasConnections.Store(true) + req := &client.Request{ + Query: options.Body.Query, + OperationName: options.Body.OperationName, + Variables: options.Body.Variables, + Extensions: options.Body.Extensions, } -} -func (c *subscriptionClient) handleClientUnsubscribe(id uint64) { - fd, ok := c.netPollState.triggers[id] - if !ok { - return - } - delete(c.netPollState.triggers, id) - conn, ok := c.netPollState.connections[fd] - if !ok { - return - } - delete(c.netPollState.connections, fd) - _ = c.netPoll.Remove(conn.netConn) - conn.handler.ClientClose() - // if we have no connections left, we stop the ticker - if len(c.netPollState.connections) == 0 { - c.netPollState.waitForEventsTicker.Stop() - c.netPollState.waitForEventsTick = nil - c.netPollState.hasConnections.Store(false) - } + return opts, req, nil } -func (c *subscriptionClient) handleServerUnsubscribe(fd int) { - conn, ok := c.netPollState.connections[fd] - if !ok { - return - } - delete(c.netPollState.connections, fd) - delete(c.netPollState.triggers, conn.id) - _ = c.netPoll.Remove(conn.netConn) - conn.handler.ServerClose() - // if we have no connections left, we stop the ticker - if len(c.netPollState.connections) == 0 { - c.netPollState.waitForEventsTicker.Stop() - c.netPollState.waitForEventsTick = nil - c.netPollState.hasConnections.Store(false) - } -} - -func (c *subscriptionClient) handleConnectionEvent(conn *connection) bool { - data, err := readMessage(conn.netConn, c.frameTimeout, c.readTimeout) - if err != nil { - return handleConnectionError(err) +// mapWSSubprotocol maps the string subprotocol to the client.WSSubprotocol type. +func mapWSSubprotocol(proto string) client.WSSubprotocol { + switch proto { + case string(client.SubprotocolGraphQLWS): + return client.SubprotocolGraphQLWS + case string(client.SubprotocolGraphQLTransportWS): + return client.SubprotocolGraphQLTransportWS + default: + return client.SubprotocolAuto } - return conn.handler.HandleMessage(data) } -func handleConnectionError(err error) (done bool) { - netOpErr := &net.OpError{} - if errors.As(err, &netOpErr) { - return !netOpErr.Timeout() +// formatUpstreamServiceError formats a connection-level error as a GraphQL error +// response with the configured error extension code. If the error chain contains +// a WebSocket close error, the close code and reason are included in extensions. +func formatUpstreamServiceError(err error, code string) []byte { + type errorExtensions struct { + Code string `json:"code"` + CloseCode int `json:"closeCode,omitempty"` + Reason string `json:"closeReason,omitempty"` } - // Check if we have errors during reading from the connection - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - return true + type graphqlError struct { + Message string `json:"message"` + Extensions errorExtensions `json:"extensions"` } - // Check if we have a context error - if errors.Is(err, context.DeadlineExceeded) { - return false - } - - // Check if the error is a connection reset by peer - if errors.Is(err, syscall.ECONNRESET) { - return true - } - if errors.Is(err, syscall.EPIPE) { - return true + gqlErr := graphqlError{ + Message: "upstream service error", + Extensions: errorExtensions{Code: code}, } - // Check if the error is a closed network connection. Introduced in go 1.16. - // This replaces the string match of "use of closed network connection" - if errors.Is(err, net.ErrClosed) { - return true + var closeErr websocket.CloseError + if errors.As(err, &closeErr) { + gqlErr.Extensions.CloseCode = int(closeErr.Code) + gqlErr.Extensions.Reason = closeErr.Reason } - // Check if the error is closed websocket connection - if errors.As(err, &wsutil.ClosedError{}) { - return true + resp := struct { + Errors []graphqlError `json:"errors"` + }{ + Errors: []graphqlError{gqlErr}, } - - return false + data, _ := json.Marshal(resp) + return data } -// readMessage reads a message from the connection -func readMessage(conn net.Conn, frameTimeout time.Duration, readTimeout time.Duration) ([]byte, error) { - controlHandler := wsutil.ControlFrameHandler(conn, ws.StateClientSide) - rd := &wsutil.Reader{ - Source: conn, - State: ws.StateClientSide, - CheckUTF8: true, - SkipHeaderCheck: false, - OnIntermediate: controlHandler, +// formatSubscriptionError formats an error as a GraphQL error response. +func formatSubscriptionError(err error) []byte { + errResponse := struct { + Errors []struct { + Message string `json:"message"` + } `json:"errors"` + }{ + Errors: []struct { + Message string `json:"message"` + }{ + {Message: err.Error()}, + }, } + data, _ := json.Marshal(errResponse) + return data +} - for { - // This method is used to check if we have data on the connection. The timeout needs to be much smaller - // than the readTimeout to ensure we don't block the connection for too long. If we have no data, we move - // on to the next connection. - err := conn.SetReadDeadline(time.Now().Add(frameTimeout)) - if err != nil { - return nil, err - } +// GraphQLSubscriptionClientFactory abstracts the way of creating a new GraphQLSubscriptionClient. +// This can be very handy for testing purposes. +type GraphQLSubscriptionClientFactory interface { + NewSubscriptionClient(ctx context.Context, options ...SubscriptionClientOption) GraphQLSubscriptionClient +} - // If we have data, we can read it. Otherwise, it will timeout and we wait for the next epoll tick - hdr, err := rd.NextFrame() - if err != nil { - // A timeout will not close the connection but return an error - return nil, err - } - if hdr.OpCode.IsControl() { - // The controlHandler writes the control frames. - // We need to work with a proper timeout to ensure we don't block forever. - err := conn.SetWriteDeadline(time.Now().Add(frameTimeout)) - if err != nil { - return nil, err - } - // Handles PING/PONG and CLOSE frames, but only on the ws protocol level - // We still need to handle the PING/PONG frames on the application protocol level - if err := controlHandler(hdr, rd); err != nil { - return nil, err - } - continue - } +type DefaultSubscriptionClientFactory struct{} - // We are only interested in text frames - if hdr.OpCode&ws.OpText == 0 { - // If we see anything else than a text frame, we need to discard the frame - if err := rd.Discard(); err != nil { - return nil, err - } - continue - } +func (d *DefaultSubscriptionClientFactory) NewSubscriptionClient(ctx context.Context, options ...SubscriptionClientOption) GraphQLSubscriptionClient { + return NewGraphQLSubscriptionClient(ctx, options...) +} - // We limit the amount of time we wait for a message to be read from the connection - // This is important to ensure we don't block the connection for too long - err = conn.SetReadDeadline(time.Now().Add(readTimeout)) - if err != nil { - return nil, err - } - return io.ReadAll(rd) - } +func IsDefaultGraphQLSubscriptionClient(client GraphQLSubscriptionClient) bool { + _, ok := client.(*subscriptionClientV2) + return ok } diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go index 86dd57c030..6ee1b013d0 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_subscription_client_test.go @@ -2,2570 +2,106 @@ package graphql_datasource import ( "context" - "encoding/base64" "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "runtime" - "strings" - "sync" "testing" - "time" - "github.com/buger/jsonparser" - "github.com/coder/websocket" - ll "github.com/jensneuse/abstractlogger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.uber.org/atomic" - "go.uber.org/goleak" - "go.uber.org/zap" + client "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -func logger() ll.Logger { - logger, err := zap.NewDevelopmentConfig().Build() - if err != nil { - panic(err) - } - - return ll.NewZapLogger(logger, ll.DebugLevel) +type testBridgeUpdater struct { + updates [][]byte + errors [][]byte + completed bool + done bool } -func TestGetConnectionInitMessageHelper(t *testing.T) { - var callback OnWsConnectionInitCallback = func(ctx context.Context, url string, header http.Header) (json.RawMessage, error) { - return json.RawMessage(`{"authorization":"secret"}`), nil - } - - tests := []struct { - name string - callback *OnWsConnectionInitCallback - want string - }{ - { - name: "without payload", - callback: nil, - want: `{"type":"connection_init"}`, - }, - { - name: "with payload", - callback: &callback, - want: `{"type":"connection_init","payload":{"authorization":"secret"}}`, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - client := subscriptionClient{onWsConnectionInitCallback: tt.callback} - got, err := client.getConnectionInitMessage(context.Background(), "", nil) - require.NoError(t, err) - require.NotEmpty(t, got) - - assert.Equal(t, tt.want, string(got)) - }) - } +func (t *testBridgeUpdater) Update(data []byte) { + t.updates = append(t.updates, data) } -func TestWebsocketSubscriptionClientImmediateClientCancel(t *testing.T) { - serverInvocations := atomic.NewInt64(0) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - serverInvocations.Inc() - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, +func (t *testBridgeUpdater) UpdateSubscription(id resolve.SubscriptionIdentifier, data []byte) {} - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.Error(t, err) - }() - assert.Eventuallyf(t, func() bool { - return serverInvocations.Load() == 0 - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() +func (t *testBridgeUpdater) Complete() { + t.completed = true } -func TestWebsocketSubscriptionClientWithServerDisconnect(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - - _, _, err = conn.Read(ctx) - assert.Error(t, err) - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - serverCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") +func (t *testBridgeUpdater) Error(data []byte) { + t.errors = append(t.errors, data) } -// didnt configure subprotocol, but the subgraph return graphql-ws -func TestSubprotocolNegotiationWithGraphQLWS(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-ws"}, - }) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() +func (t *testBridgeUpdater) Done() { + t.done = true } -// didnt configure subprotocol, but the subgraph return graphql-transport-ws -func TestSubprotocolNegotiationWithGraphQLTransportWS(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-transport-ws"}, - }) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() +func (t *testBridgeUpdater) CloseSubscription(id resolve.SubscriptionIdentifier) { } -// In this case the subgraph doesnt return the subprotocol and we didnt configure the subprotocol, so falls back to graphql-ws -func TestSubprotocolNegotiationWithNoSubprotocol(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() +func (t *testBridgeUpdater) Subscriptions() map[context.Context]resolve.SubscriptionIdentifier { + return map[context.Context]resolve.SubscriptionIdentifier{} } -func TestSubprotocolNegotiationWithConfiguredGraphQLWS(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLWS, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestSubprotocolNegotiationWithConfiguredGraphQLTransportWS(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdates(t, time.Second, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestWebSocketClientLeaks(t *testing.T) { - defer goleak.VerifyNone(t, - goleak.IgnoreCurrent(), // ignore the test itself - ) - serverDone := &sync.WaitGroup{} - serverDone.Add(2) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - serverDone.Done() - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - })) - defer server.Close() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - wg := &sync.WaitGroup{} - wg.Add(2) - for i := 0; i < 2; i++ { - go func(i int) { - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - updater := &testSubscriptionUpdater{} - err := client.SubscribeAsync(resolve.NewContext(ctx), uint64(i), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(uint64(i)) - clientCancel() - wg.Done() - }(i) - } - wg.Wait() - time.Sleep(time.Second) - serverCancel() - time.Sleep(time.Second) - serverDone.Wait() -} - -func TestAsyncSubscribe(t *testing.T) { - t.Parallel() - if runtime.GOOS == "windows" { - t.SkipNow() - } - t.Run("subscribe async", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - defer close(serverDone) - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("server timeout", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - close(serverDone) - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 2) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("server complete", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"complete","id":"1"}`)) - assert.NoError(t, err) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} +func TestBuildMessageHandlerRoutesEachMessageTypeCorrectly(t *testing.T) { + t.Run("error is upstream service error for connection error", func(t *testing.T) { + updater := &testBridgeUpdater{} + handler := buildMessageHandler(updater, "DOWNSTREAM_SERVICE_ERROR") - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("server ka", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ka"}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ka"}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"complete","id":"1"}`)) - assert.NoError(t, err) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("long timeout", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 2) + handler(&client.Message{Type: client.MessageTypeConnectionError, Err: client.ErrConnectionClosed}) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - time.Sleep(time.Second) - assert.Equal(t, false, client.netPollState.hasConnections.Load()) + require.True(t, updater.done) + require.Len(t, updater.errors, 1) + assert.Contains(t, string(updater.errors[0]), "DOWNSTREAM_SERVICE_ERROR") + require.Empty(t, updater.updates) + require.False(t, updater.completed) }) - t.Run("forever timeout", func(t *testing.T) { - t.Parallel() - globalCtx, cancel := context.WithCancel(context.Background()) - defer cancel() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - <-globalCtx.Done() - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} + t.Run("error contains payload for graphql error", func(t *testing.T) { + updater := &testBridgeUpdater{} + handler := buildMessageHandler(updater, "DOWNSTREAM_SERVICE_ERROR") - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, + handler(&client.Message{ + Type: client.MessageTypeError, + Payload: &client.ExecutionResult{ + Errors: json.RawMessage(`[{"message":"field not found"}]`), }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*3, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - time.Sleep(time.Second * 2) - }) - t.Run("graphql-ws", func(t *testing.T) { - t.Parallel() - t.Run("happy path", func(t *testing.T) { - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"data","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"data","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"data","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - - ctx = conn.CloseRead(ctx) - <-ctx.Done() - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - - clientCancel() - - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("connection error", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"connection_error"}`)) - assert.NoError(t, err) - - _ = conn.Close(websocket.StatusNormalClosure, "done") - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*5, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"errors":[{"message":"connection error"}]}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - assert.Equal(t, false, client.netPollState.hasConnections.Load()) - }) - t.Run("error object", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"error","payload":{"message":"ws error"}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*5, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"errors":[{"message":"ws error"}]}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("error array", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - close(serverDone) - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"error","payload":[{"message":"ws error"}]}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*5, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"errors":[{"message":"ws error"}]}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - }) - t.Run("graphql-transport-ws", func(t *testing.T) { - t.Parallel() - t.Run("happy path", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - - ctx = conn.CloseRead(ctx) - <-ctx.Done() - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("happy path no netPoll", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - WithNetPollConfiguration(NetPollConfiguration{ - Enable: false, - }), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("happy path no netPoll two clients", func(t *testing.T) { - t.Parallel() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - })) - defer server.Close() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - WithNetPollConfiguration(NetPollConfiguration{ - Enable: false, - }), - ).(*subscriptionClient) - wg := &sync.WaitGroup{} - wg.Add(2) - for i := 0; i < 2; i++ { - go func(i int) { - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - updater := &testSubscriptionUpdater{} - err := client.SubscribeAsync(resolve.NewContext(ctx), uint64(i), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - wg.Done() - }(i) - } - wg.Wait() }) - t.Run("ping", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ping"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"pong"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("ka", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ka"}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*10, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second*5, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("error object", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"error","payload":{"message":"ws error"}}`)) - assert.NoError(t, err) - - _ = conn.Close(websocket.StatusNormalClosure, "done") - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - updater.AwaitUpdates(t, time.Second*5, 2) - assert.Equal(t, 2, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"errors":[{"message":"ws error"}]}`, updater.updates[1]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("error array", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"error","payload":[{"message":"ws error"}]}`)) - assert.NoError(t, err) - - _ = conn.Close(websocket.StatusNormalClosure, "done") - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - updater.AwaitUpdates(t, time.Second*5, 2) - assert.Equal(t, 2, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"errors":[{"message":"ws error"}]}`, updater.updates[1]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("data error", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"data","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - updater.AwaitUpdates(t, time.Second*5, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - }) - t.Run("connection error", func(t *testing.T) { - t.Parallel() - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - time.Sleep(time.Second * 1) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"connection_error"}`)) - assert.NoError(t, err) - - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - updater.AwaitUpdates(t, time.Second*5, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - client.Unsubscribe(1) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() - }) + require.True(t, updater.done) + require.Len(t, updater.errors, 1) + assert.Contains(t, string(updater.errors[0]), "field not found") + require.Empty(t, updater.updates) + require.False(t, updater.completed) }) -} - -func TestClientToSubgraphPingPong(t *testing.T) { - t.Parallel() - if runtime.GOOS == "windows" { - t.Skip("Skipping test on Windows as it's not reliable") - } - - t.Run("client sends ping message after configured interval", func(t *testing.T) { - t.Parallel() - - serverDone := make(chan struct{}) // to signal server done - // buffered channels and non-blocking send to avoid double-close panics if events repeat - pingReceived := make(chan struct{}, 1) // signaled when the server receives a ping - payloadSend := make(chan struct{}, 1) // signaled when the server sends a payload - - // Create test server that will handle the WebSocket connection - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Sec-WebSocket-Protocol", ProtocolGraphQLTWS) - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{ProtocolGraphQLTWS}, - }) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - close(serverDone) - }() - - ctx := context.Background() - - // Handle connection initialization - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - // Handle subscription start - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - // Send initial data - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"initial data"}}}}`)) - assert.NoError(t, err) - - // Track what messages we've received - receivedPing := false - receivedComplete := false - - // Create a context with timeout for reading messages - readCtx, cancel := context.WithTimeout(ctx, 3*time.Second) - defer cancel() - - // Process up to 5 messages or until we see both a ping and a complete - for i := 0; i < 5; i++ { - if receivedPing && receivedComplete { - break - } - - _, data, err = conn.Read(readCtx) - if err != nil { - // Connection closed or timeout - t.Logf("Connection read ended: %v", err) - break - } - - messageStr := string(data) - t.Logf("Received message: %s", messageStr) - - switch messageStr { - case `{"type":"ping"}`: - receivedPing = true - select { - case pingReceived <- struct{}{}: - default: - } - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"pong"}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"after ping-pong"}}}}`)) - assert.NoError(t, err) - select { - case payloadSend <- struct{}{}: - default: - } - case `{"id":"1","type":"complete"}`: - receivedComplete = true - } - } - // Test is successful if we received a ping message - if !receivedPing { - t.Error("Did not receive ping message from client") - } - })) - defer server.Close() + t.Run("update is delivered without completing for data message", func(t *testing.T) { + updater := &testBridgeUpdater{} + handler := buildMessageHandler(updater, "DOWNSTREAM_SERVICE_ERROR") - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - // Create subscription client with a short ping interval for testing - pingInterval := 400 * time.Millisecond - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - WithPingInterval(pingInterval), - WithPingTimeout(1*time.Second), - WithNetPollConfiguration(NetPollConfiguration{ - Enable: true, - TickInterval: 100 * time.Millisecond, - BufferSize: 10, - MaxEventWorkers: 2, - WaitForNumEvents: 1, - }), - ).(*subscriptionClient) - - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 1, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, + handler(&client.Message{ + Type: client.MessageTypeData, + Payload: &client.ExecutionResult{ + Data: json.RawMessage(`{"foo":"bar"}`), }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - // Wait for ping to be received before unsubscribing - select { - case <-pingReceived: - t.Log("Ping received successfully") - case <-time.After(2 * time.Second): - t.Log("Timed out waiting for ping, will unsubscribe anyway") - } - // don't unsubscribe immediately, give the server time to send a payload - select { - case <-payloadSend: - t.Log("Payload sent successfully") - case <-time.After(2 * time.Second): - t.Log("Timed out waiting for sent payload, will unsubscribe anyway") - } - - // Cleanup - client.Unsubscribe(1) - - // Wait for server to finish - select { - case <-serverDone: - // Server completed successfully - case <-time.After(5 * time.Second): - t.Fatal("Timed out waiting for server to complete") - } + }) - clientCancel() - serverCancel() + require.Len(t, updater.updates, 1) + assert.JSONEq(t, `{"data":{"foo":"bar"}}`, string(updater.updates[0])) + require.False(t, updater.done) + require.False(t, updater.completed) + require.Empty(t, updater.errors) }) - t.Run("client responds with pong when server sends ping", func(t *testing.T) { - t.Parallel() - - pongReceived := make(chan struct{}) - serverDone := make(chan struct{}) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Sec-WebSocket-Protocol", ProtocolGraphQLTWS) - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{ProtocolGraphQLTWS}, - }) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "done") - }() - - ctx := context.Background() - - // Handle connection initialization - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - // Handle subscription start - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - // Send initial data - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"initial data"}}}}`)) - assert.NoError(t, err) - - // Send ping message - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ping"}`)) - assert.NoError(t, err) - - // Wait for pong response from client - msgType, data, err = conn.Read(ctx) - if err != nil { - t.Errorf("Error reading pong: %v", err) - return - } - - if string(data) == `{"type":"pong"}` { - assert.Equal(t, websocket.MessageText, msgType) - // Send another data message - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"after ping-pong"}}}}`)) - assert.NoError(t, err) - - close(pongReceived) - } - - // Wait for client to unsubscribe (complete message) - readTimeout := time.NewTimer(3 * time.Second) - defer readTimeout.Stop() - - readDone := make(chan struct{}) - go func() { - msgType, data, _ = conn.Read(ctx) - close(readDone) - }() - - select { - case <-readDone: - // Successfully read client message - case <-readTimeout.C: - // Timeout is fine, we're just waiting for unsubscribe - } - - close(serverDone) - })) - defer server.Close() - - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() + t.Run("complete and done are set for complete message", func(t *testing.T) { + updater := &testBridgeUpdater{} + handler := buildMessageHandler(updater, "DOWNSTREAM_SERVICE_ERROR") - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - WithReadTimeout(100*time.Millisecond), - WithNetPollConfiguration(NetPollConfiguration{ - Enable: true, - TickInterval: 100 * time.Millisecond, - BufferSize: 10, - MaxEventWorkers: 2, - WaitForNumEvents: 1, - }), - ).(*subscriptionClient) + handler(&client.Message{Type: client.MessageTypeComplete}) - updater := &testSubscriptionUpdater{} - - err := client.SubscribeAsync(resolve.NewContext(ctx), 2, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - select { - case <-pongReceived: - t.Log("Server received pong successfully") - case <-time.After(2 * time.Second): - t.Log("Timed out waiting for pong in server, will unsubscribe anyway") - } - - // Verify we receive at least the initial data - updater.mux.Lock() - updatesCount := len(updater.updates) - firstUpdate := "" - if updatesCount > 0 { - firstUpdate = updater.updates[0] - } - updater.mux.Unlock() - - assert.GreaterOrEqual(t, updatesCount, 1) - assert.Equal(t, `{"data":{"messageAdded":{"text":"initial data"}}}`, firstUpdate) - - // Cleanup - client.Unsubscribe(2) - t.Log("client unsubscribed") - - // Wait for server to finish - select { - case <-serverDone: - // Server completed successfully - case <-time.After(2 * time.Second): - t.Fatal("Timed out waiting for server to complete") - } - - clientCancel() - serverCancel() + require.True(t, updater.done) + require.True(t, updater.completed) + require.Empty(t, updater.errors) }) } - -func TestClientClosesConnectionOnPingTimeout(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("Skipping test on Windows as it's not reliable") - } - - t.Parallel() - - serverDone := make(chan struct{}) - pingReceived := make(chan struct{}, 1) // Buffer 1 in case ping arrives slightly late - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Sec-WebSocket-Protocol", ProtocolGraphQLTWS) - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{ProtocolGraphQLTWS}, - }) - assert.NoError(t, err) - defer func() { - _ = conn.Close(websocket.StatusNormalClosure, "server done") - close(serverDone) - }() - - ctx := context.Background() - - // Handle connection initialization - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - // Handle subscription start - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - // Use regexp because client might generate different IDs - assert.Regexp(t, `{"id":".*","type":"subscribe","payload":{.*}}`, string(data)) - // More specific check for query - payloadQuery, err := jsonparser.GetString(data, "payload", "query") - if assert.NoError(t, err) { - assert.Equal(t, `subscription {messageAdded(roomName: "room"){text}}`, payloadQuery) - } - - // Send initial data - subID, err := jsonparser.GetString(data, "id") // Get the actual ID used by the client - if !assert.NoError(t, err) { - return - } - initialDataMsg := fmt.Sprintf(`{"id":"%s","type":"next","payload":{"data":{"messageAdded":{"text":"initial data"}}}}`, subID) - err = conn.Write(r.Context(), websocket.MessageText, []byte(initialDataMsg)) - assert.NoError(t, err) - - // Wait for ping, but DO NOT send pong - readCtx, cancelRead := context.WithTimeout(ctx, 5*time.Second) // Timeout for reading messages - defer cancelRead() - - hasReceivedPing := false - for !hasReceivedPing { - _, data, err = conn.Read(readCtx) - if err != nil { - t.Logf("Server read error (expected after client closes): %v", err) - // Expecting an error here eventually as the client should close the connection - assert.Error(t, err, "Server should encounter read error when client closes connection due to ping timeout") - // Signal that the server is done (connection closed) - close(serverDone) - return // Exit handler goroutine - } - - messageStr := string(data) - t.Logf("Server received: %s", messageStr) - if messageStr == `{"type":"ping"}` { - t.Log("Server received ping, NOT sending pong.") - hasReceivedPing = true - select { - case pingReceived <- struct{}{}: - default: // Avoid blocking if channel is full - } - } else if strings.Contains(messageStr, `"type":"complete"`) { - // Client might send complete before closing if test runs fast - t.Log("Server received complete from client.") - } else { - t.Logf("Server received unexpected message type: %s", messageStr) - } - } - - // Keep reading until the connection is closed by the client - for { - // Use a timeout context to make sure we don't hang indefinitely - readCtx, cancel := context.WithTimeout(ctx, 2*time.Second) - _, data, err = conn.Read(readCtx) - cancel() - - if err != nil { - t.Logf("Server read error after ping (expected): %v", err) - assert.Error(t, err, "Server should encounter read error after client closes connection") - return // Exit handler goroutine - } - - // Log any messages received before connection close - t.Logf("Server still receiving messages: %s", string(data)) - } - })) - defer server.Close() - - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - WithPingInterval(500*time.Millisecond), - WithPingTimeout(100*time.Millisecond), - // Need netpoll enabled for ping/pong handling - WithNetPollConfiguration(NetPollConfiguration{ - Enable: true, - TickInterval: 100 * time.Millisecond, - BufferSize: 10, - MaxEventWorkers: 2, - WaitForNumEvents: 1, - }), - ).(*subscriptionClient) - - updater := &testSubscriptionUpdater{} - - // Use a unique ID for async subscription - subscriptionID := uint64(42) - - err := client.SubscribeAsync(resolve.NewContext(ctx), subscriptionID, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - WsSubProtocol: ProtocolGraphQLTWS, - }, updater) - assert.NoError(t, err) - - // Wait for initial data - updater.AwaitUpdates(t, 3*time.Second, 1) - updater.mux.Lock() - updatesCount := len(updater.updates) - firstUpdate := "" - if updatesCount > 0 { - firstUpdate = updater.updates[0] - } - updater.mux.Unlock() - - require.Equal(t, 1, updatesCount, "Client should receive initial data") - assert.Equal(t, `{"data":{"messageAdded":{"text":"initial data"}}}`, firstUpdate) - - // Wait for the server to confirm it received a ping - select { - case <-pingReceived: - t.Log("Test confirmed server received ping.") - case <-time.After(3 * time.Second): // Should receive ping within ~pingInterval + read time - t.Fatal("Timed out waiting for server to receive ping") - } - - // Wait for server to signal it's done (connection closed by client) - select { - case <-serverDone: - t.Log("Server confirmed connection closed.") - // Success: server detected connection closure - case <-time.After(5 * time.Second): // Should happen within ~2*pingInterval + processing time - t.Fatal("Timed out waiting for server to detect connection closure") - } - - // Explicitly unsubscribe just in case, although it should be closed already - client.Unsubscribe(subscriptionID) - clientCancel() // Cancel client context - serverCancel() // Cancel server context (though serverDone should ensure it exited) -} - -func TestWebSocketUpgradeFailures(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - statusCode int - headers map[string]string - expectError bool - errorContains string - }{ - { - name: "HTTP 400 Bad Request", - statusCode: http.StatusBadRequest, - headers: map[string]string{"Content-Type": "text/plain"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 401 Unauthorized", - statusCode: http.StatusUnauthorized, - headers: map[string]string{"WWW-Authenticate": "Bearer"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 403 Forbidden", - statusCode: http.StatusForbidden, - headers: map[string]string{"Content-Type": "application/json"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 404 Not Found", - statusCode: http.StatusNotFound, - headers: map[string]string{"Content-Type": "text/html"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 500 Internal Server Error", - statusCode: http.StatusInternalServerError, - headers: map[string]string{"Content-Type": "application/json"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 502 Bad Gateway", - statusCode: http.StatusBadGateway, - headers: map[string]string{"Content-Type": "text/html"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 503 Service Unavailable", - statusCode: http.StatusServiceUnavailable, - headers: map[string]string{"Retry-After": "60"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - { - name: "HTTP 200 OK (wrong status for WebSocket)", - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "application/json"}, - expectError: true, - errorContains: "failed to upgrade connection", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for key, value := range tc.headers { - w.Header().Set(key, value) - } - w.WriteHeader(tc.statusCode) - _, _ = fmt.Fprintf(w, `{"error": "WebSocket upgrade failed", "status": %d}`, tc.statusCode) - })) - defer server.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - - wsURL := strings.Replace(server.URL, "http://", "ws://", 1) - - updater := &testSubscriptionUpdater{} - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: wsURL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - - if tc.expectError { - require.ErrorContains(t, err, tc.errorContains) - - // Verify the error is of the correct type - var upgradeErr *UpgradeRequestError - require.ErrorAs(t, err, &upgradeErr) - require.Equal(t, tc.statusCode, upgradeErr.StatusCode) - require.Equal(t, server.URL, upgradeErr.URL) - } else { - assert.NoError(t, err, "Expected no error for status code %d", tc.statusCode) - } - }) - } -} - -func TestInvalidWebSocketAcceptKey(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - acceptKeyHandler func(challengeKey string) string - expectError bool - errorContains string - }{ - { - name: "Missing Sec-WebSocket-Accept header", - acceptKeyHandler: func(challengeKey string) string { - return "" // Don't set the header - }, - expectError: true, - errorContains: "invalid Sec-WebSocket-Accept", - }, - { - name: "Malformed base64 Sec-WebSocket-Accept", - acceptKeyHandler: func(challengeKey string) string { - return "not-valid-base64!!!" - }, - expectError: true, - errorContains: "invalid Sec-WebSocket-Accept", - }, - { - name: "Correct length but wrong content", - acceptKeyHandler: func(challengeKey string) string { - // 20 bytes (not the SHA-1 of challengeKey+GUID) - return base64.StdEncoding.EncodeToString([]byte("12345678901234567890")) - }, - expectError: true, - errorContains: "invalid Sec-WebSocket-Accept", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - var receivedChallengeKey string - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedChallengeKey = r.Header.Get("Sec-WebSocket-Key") - require.NotEmpty(t, receivedChallengeKey, "Challenge key should be present in request") - - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Sec-WebSocket-Version", "13") - - acceptKey := tc.acceptKeyHandler(receivedChallengeKey) - if acceptKey != "" { - w.Header().Set("Sec-WebSocket-Accept", acceptKey) - } - // If acceptKey is empty, we don't set the header (simulating missing header) - - w.WriteHeader(http.StatusSwitchingProtocols) - - // Close the connection immediately to prevent hanging - // This simulates a server that sends 101 but then closes - if hijacker, ok := w.(http.Hijacker); ok { - conn, _, err := hijacker.Hijack() - if err == nil { - conn.Close() - } - } - })) - defer server.Close() - - // Create subscription client - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithLogger(logger()), - ).(*subscriptionClient) - - wsURL := strings.Replace(server.URL, "http://", "ws://", 1) - - updater := &testSubscriptionUpdater{} - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: wsURL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - - require.Error(t, err) - require.ErrorContains(t, err, tc.errorContains) - require.NotEmpty(t, receivedChallengeKey) - }) - } -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go deleted file mode 100644 index 8ebf5447cd..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler.go +++ /dev/null @@ -1,379 +0,0 @@ -package graphql_datasource - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net" - "sync/atomic" - "time" - - "github.com/buger/jsonparser" - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsutil" - "github.com/jensneuse/abstractlogger" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" -) - -// gqlTWSConnectionHandler is responsible for handling a connection to an origin -// it is responsible for managing all subscriptions using the underlying WebSocket connection -// if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate -type gqlTWSConnectionHandler struct { - // The underlying net.Conn. Only used for netPoll. Should not be used to shutdown the connection. - conn net.Conn - requestContext, engineContext context.Context - log abstractlogger.Logger - options GraphQLSubscriptionOptions - updater resolve.SubscriptionUpdater - lastPingSentUnix atomic.Int64 - pingTimeout time.Duration - shuttingDown atomic.Bool -} - -func (h *gqlTWSConnectionHandler) ServerClose() { - h.shuttingDown.Store(true) - - // Because the server closes the connection, we need to send a close frame to the event loop. - h.updater.Close(resolve.SubscriptionCloseKindDownstreamServiceError) - - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) - _ = h.conn.Close() -} - -// ClientClose is called when the client closes the connection. Is called when the trigger is shutdown with all subscriptions. -func (h *gqlTWSConnectionHandler) ClientClose() { - h.shuttingDown.Store(true) - - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = wsutil.WriteClientText(h.conn, []byte(`{"id":"1","type":"complete"}`)) - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) - _ = h.conn.Close() - -} - -func (h *gqlTWSConnectionHandler) Subscribe() error { - return h.subscribe() -} - -func (h *gqlTWSConnectionHandler) HandleMessage(data []byte) (done bool) { - messageType, err := jsonparser.GetString(data, "type") - if err != nil { - return false - } - switch messageType { - case messageTypePing: - h.handleMessageTypePing() - return false - case messageTypeNext: - h.handleMessageTypeNext(data) - return false - case messageTypeComplete: - h.handleMessageTypeComplete(data) - return true - case messageTypeError: - h.handleMessageTypeError(data) - return false - case messageTypeData, messageTypeConnectionError: - h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-transport-ws, but currently it is set to graphql-ws") - return true - case messageTypePong: - h.handleMessageTypePong() - return false - default: - h.log.Error("unknown message type", abstractlogger.String("type", messageType)) - return false - } -} - -func (h *gqlTWSConnectionHandler) NetConn() net.Conn { - return h.conn -} - -func newGQLTWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, l abstractlogger.Logger) *connection { - handler := &gqlTWSConnectionHandler{ - conn: conn, - requestContext: requestContext, - engineContext: engineContext, - log: l, - updater: updater, - options: options, - pingTimeout: options.pingTimeout, - } - - return &connection{ - handler: handler, - netConn: conn, - } -} - -func (h *gqlTWSConnectionHandler) StartBlocking() error { - readCtx, cancel := context.WithCancel(h.requestContext) - dataCh := make(chan []byte) - errCh := make(chan error) - - defer func() { - cancel() - h.unsubscribeAllAndCloseConn() - }() - - err := h.subscribe() - if err != nil { - return err - } - - pingTicker := time.NewTicker(h.options.pingInterval) - defer pingTicker.Stop() - - go h.readBlocking(readCtx, h.options.readTimeout, dataCh, errCh) - - for { - select { - case <-h.engineContext.Done(): - return h.engineContext.Err() - case <-readCtx.Done(): - return readCtx.Err() - case <-pingTicker.C: - h.Ping() - case err := <-errCh: - h.log.Error("gqlWSConnectionHandler.StartBlocking", abstractlogger.Error(err)) - h.broadcastErrorMessage(err) - return err - case data := <-dataCh: - messageType, err := jsonparser.GetString(data, "type") - if err != nil { - continue - } - - switch messageType { - case messageTypePong: - h.handleMessageTypePong() - continue - case messageTypePing: - h.handleMessageTypePing() - continue - case messageTypeNext: - h.handleMessageTypeNext(data) - continue - case messageTypeComplete: - h.handleMessageTypeComplete(data) - return nil - case messageTypeError: - h.handleMessageTypeError(data) - continue - case messageTypeConnectionKeepAlive: - continue - case messageTypeData, messageTypeConnectionError: - h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-transport-ws, but currently it is set to graphql-ws") - return errors.New("invalid subprotocol") - default: - h.log.Error("unknown message type", abstractlogger.String("type", messageType)) - continue - } - } - } -} - -func (h *gqlTWSConnectionHandler) unsubscribeAllAndCloseConn() { - h.unsubscribe() - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) - _ = h.conn.Close() -} - -func (h *gqlTWSConnectionHandler) unsubscribe() { - h.updater.Complete() - req := fmt.Sprintf(completeMessage, "1") - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - err := wsutil.WriteClientText(h.conn, []byte(req)) - if err != nil { - h.log.Error("failed to write complete message", abstractlogger.Error(err)) - } -} - -// subscribe adds a new Subscription to the gqlTWSConnectionHandler and sends the subscribeMessage to the origin -func (h *gqlTWSConnectionHandler) subscribe() error { - graphQLBody, err := json.Marshal(h.options.Body) - if err != nil { - return err - } - subscribeRequest := fmt.Sprintf(subscribeMessage, "1", string(graphQLBody)) - if err = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { - return err - } - if err = wsutil.WriteClientText(h.conn, []byte(subscribeRequest)); err != nil { - return err - } - return nil -} - -func (h *gqlTWSConnectionHandler) broadcastErrorMessage(err error) { - errMsg := fmt.Sprintf(errorMessageTemplate, err) - h.updater.Update([]byte(errMsg)) -} - -func (h *gqlTWSConnectionHandler) handleMessageTypeComplete(data []byte) { - id, err := jsonparser.GetString(data, "id") - if err != nil { - return - } - if id != "1" { - return - } - h.updater.Complete() -} - -func (h *gqlTWSConnectionHandler) handleMessageTypeError(data []byte) { - id, err := jsonparser.GetString(data, "id") - if err != nil { - return - } - if id != "1" { - return - } - value, valueType, _, err := jsonparser.Get(data, "payload") - if err != nil { - h.log.Error( - "failed to get payload from error message", - abstractlogger.Error(err), - abstractlogger.ByteString("raw message", data), - ) - h.updater.Update([]byte(internalError)) - return - } - - switch valueType { - case jsonparser.Array: - response := []byte(`{}`) - response, err = jsonparser.Set(response, value, "errors") - if err != nil { - h.log.Error( - "failed to set errors response", - abstractlogger.Error(err), - abstractlogger.ByteString("raw message", value), - ) - h.updater.Update([]byte(internalError)) - return - } - h.updater.Update(response) - case jsonparser.Object: - response := []byte(`{"errors":[]}`) - response, err = jsonparser.Set(response, value, "errors", "[0]") - if err != nil { - h.updater.Update([]byte(internalError)) - return - } - h.updater.Update(response) - default: - h.updater.Update([]byte(internalError)) - } -} - -func (h *gqlTWSConnectionHandler) handleMessageTypePing() { - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - err := wsutil.WriteClientText(h.conn, []byte(pongMessage)) - if err != nil { - h.log.Error("failed to write pong message", abstractlogger.Error(err)) - } -} - -func (h *gqlTWSConnectionHandler) handleMessageTypePong() { - // We received a pong message from the server. We can reset it. - h.lastPingSentUnix.Store(0) -} - -func (h *gqlTWSConnectionHandler) Ping() { - - // Do nothing if the connection is shutting down - if h.shuttingDown.Load() { - h.log.Debug("ping skipped. connection is shutting down") - return - } - - lastPingTimestamp := h.lastPingSentUnix.Load() - - // We will detect a dead connection not immediately but on the next ping interval. - if lastPingTimestamp > 0 { - pingTime := time.Unix(0, lastPingTimestamp) - duration := time.Since(pingTime) - - if duration > h.pingTimeout { - h.log.Error("ping timeout exceeded. Closing connection") - // We close the connection because the ping timeout has been exceeded, - // and we assume the connection is dead. ServerClose will send a done event to the client - // event loop to close all triggers and subscriptions - h.ServerClose() - } - - // We don't want to send another ping if one is already in flight - return - } - - // Start measuring the time since to write the message to the connection, including the IO time - h.lastPingSentUnix.Store(time.Now().UnixNano()) - - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - err := wsutil.WriteClientText(h.conn, []byte(pingMessage)) - - if err != nil { - h.log.Error("failed to write ping message", abstractlogger.Error(err)) - return - } -} - -func (h *gqlTWSConnectionHandler) handleMessageTypeNext(data []byte) { - id, err := jsonparser.GetString(data, "id") - if err != nil { - return - } - if id != "1" { - return - } - value, _, _, err := jsonparser.Get(data, "payload") - if err != nil { - h.log.Error( - "failed to get payload from next message", - abstractlogger.Error(err), - ) - h.updater.Update([]byte(internalError)) - return - } - - h.updater.Update(value) -} - -// readBlocking is a dedicated loop running in a separate goroutine -// because the library "github.com/coder/websocket" doesn't allow reading with a context with Timeout -// we'll block forever on reading until the context of the gqlTWSConnectionHandler stops -func (h *gqlTWSConnectionHandler) readBlocking(ctx context.Context, readTimeout time.Duration, dataCh chan []byte, errCh chan error) { - netOpErr := &net.OpError{} - for { - _ = h.conn.SetReadDeadline(time.Now().Add(readTimeout)) - data, err := wsutil.ReadServerText(h.conn) - if err != nil { - if errors.As(err, &netOpErr) { - if netOpErr.Timeout() { - select { - case <-ctx.Done(): - return - default: - continue - } - } - } - select { - case errCh <- err: - case <-ctx.Done(): - } - return - } - select { - case dataCh <- data: - case <-ctx.Done(): - return - } - } -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go deleted file mode 100644 index cc69902198..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_tws_handler_test.go +++ /dev/null @@ -1,312 +0,0 @@ -//go:build !race - -package graphql_datasource - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/coder/websocket" - "github.com/stretchr/testify/assert" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/flags" -) - -func TestWebsocketSubscriptionClient_GQLTWS(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-transport-ws"}, - }) - assert.NoError(t, err) - - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ).(*subscriptionClient) - - updater := &testSubscriptionUpdater{} - go func() { - rCtx := resolve.NewContext(ctx) - err := client.Subscribe(rCtx, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second*5, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestWebsocketSubscriptionClientPing_GQLTWS(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-transport-ws"}, - }) - assert.NoError(t, err) - - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ping"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"pong"}`, string(data)) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"complete"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ).(*subscriptionClient) - - updater := &testSubscriptionUpdater{} - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdates(t, time.Second, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestWebsocketSubscriptionClientError_GQLTWS(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-transport-ws"}, - }) - assert.NoError(t, err) - - msgType, data, err := conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"wrongQuery {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"payload":[{"message":"Unexpected Name \"wrongQuery\"","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}],"id":"1","type":"error"}`)) - assert.NoError(t, err) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1",type":"complete"}`)) - assert.NoError(t, err) - - close(serverDone) - })) - defer server.Close() - - serverCtx, serverCancel := context.WithCancel(context.Background()) - - clientCtx, clientCancel := context.WithCancel(context.Background()) - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - - updater := &testSubscriptionUpdater{} - - go func() { - err := client.Subscribe(resolve.NewContext(clientCtx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `wrongQuery {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdates(t, time.Second, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"errors":[{"message":"Unexpected Name \"wrongQuery\"","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}]}`, updater.updates[0]) - - clientCancel() - updater.AwaitDone(t, time.Second) - - serverCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") -} - -func TestWebsocketSubscriptionClient_GQLTWS_Upstream_Dies(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"graphql-transport-ws"}, - }) - assert.NoError(t, err) - - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"id":"1","type":"subscribe","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"id":"1","type":"next","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - <-serverCtx.Done() - })) - - // Wrap the listener to hijack the underlying TCP connection. - // Hijacking via http.ResponseWriter doesn't work because the WebSocket - // client already hijacks the connection before us. - wrappedListener := &listenerWrapper{ - listener: server.Listener, - } - server.Listener = wrappedListener - server.Start() - - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Second), - WithLogger(logger()), - ).(*subscriptionClient) - - updater := &testSubscriptionUpdater{} - - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - - updater.AwaitUpdates(t, time.Second, 1) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - - // Kill the upstream here. We should get an End-of-File error. - assert.NoError(t, wrappedListener.underlyingConnection.Close()) - updater.AwaitUpdates(t, time.Second, 2) - assert.Equal(t, `{"errors":[{"message":"EOF"}]}`, updater.updates[1]) - - clientCancel() - serverCancel() -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go deleted file mode 100644 index 423080e2d7..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler.go +++ /dev/null @@ -1,301 +0,0 @@ -package graphql_datasource - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net" - "time" - - "github.com/buger/jsonparser" - "github.com/gobwas/ws" - "github.com/gobwas/ws/wsutil" - "github.com/jensneuse/abstractlogger" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" -) - -// gqlWSConnectionHandler is responsible for handling a connection to an origin -// it is responsible for managing all subscriptions using the underlying WebSocket connection -// if all Subscriptions are complete or cancelled/unsubscribed the handler will terminate -type gqlWSConnectionHandler struct { - // The underlying net.Conn. Only used for netPoll. Should not be used to shutdown the connection. - conn net.Conn - requestContext, engineContext context.Context - log abstractlogger.Logger - options GraphQLSubscriptionOptions - updater resolve.SubscriptionUpdater -} - -func (h *gqlWSConnectionHandler) ServerClose() { - // Because the server closes the connection, we need to send a close frame to the event loop. - h.updater.Close(resolve.SubscriptionCloseKindDownstreamServiceError) - - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) - _ = h.conn.Close() -} - -// ClientClose is called when the client closes the connection. Is called when the trigger is shutdown with all subscriptions. -func (h *gqlWSConnectionHandler) ClientClose() { - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = wsutil.WriteClientText(h.conn, []byte(`{"type":"stop","id":"1"}`)) - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) - _ = h.conn.Close() -} - -func (h *gqlWSConnectionHandler) Subscribe() error { - return h.subscribe() -} - -func (h *gqlWSConnectionHandler) HandleMessage(data []byte) (done bool) { - messageType, err := jsonparser.GetString(data, "type") - if err != nil { - return false - } - switch messageType { - case messageTypeConnectionKeepAlive: - return false - case messageTypeData: - h.handleMessageTypeData(data) - return false - case messageTypeComplete: - h.handleMessageTypeComplete(data) - return true - case messageTypeConnectionError: - h.handleMessageTypeConnectionError() - return true - case messageTypeError: - h.handleMessageTypeError(data) - return false - default: - return false - } -} - -func (h *gqlWSConnectionHandler) NetConn() net.Conn { - return h.conn -} - -func newGQLWSConnectionHandler(requestContext, engineContext context.Context, conn net.Conn, options GraphQLSubscriptionOptions, updater resolve.SubscriptionUpdater, log abstractlogger.Logger) *connection { - handler := &gqlWSConnectionHandler{ - conn: conn, - requestContext: requestContext, - engineContext: engineContext, - log: log, - updater: updater, - options: options, - } - return &connection{ - handler: handler, - netConn: conn, - } -} - -// StartBlocking starts the single threaded event loop of the handler -// if the global context returns or the websocket connection is terminated, it will stop -func (h *gqlWSConnectionHandler) StartBlocking() error { - dataCh := make(chan []byte) - errCh := make(chan error) - readCtx, cancel := context.WithCancel(h.requestContext) - - defer func() { - cancel() - h.unsubscribeAllAndCloseConn() - }() - - err := h.subscribe() - if err != nil { - return err - } - - go h.readBlocking(readCtx, h.options.readTimeout, dataCh, errCh) - - for { - select { - case <-h.engineContext.Done(): - return h.engineContext.Err() - case <-readCtx.Done(): - return readCtx.Err() - case err := <-errCh: - if !errors.Is(err, context.Canceled) && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - h.log.Error("gqlWSConnectionHandler.StartBlocking", abstractlogger.Error(err)) - } - h.broadcastErrorMessage(err) - return err - case data := <-dataCh: - messageType, err := jsonparser.GetString(data, "type") - if err != nil { - continue - } - switch messageType { - case messageTypeData: - h.handleMessageTypeData(data) - continue - case messageTypeComplete: - h.handleMessageTypeComplete(data) - return nil - case messageTypeConnectionError: - h.handleMessageTypeConnectionError() - return nil - case messageTypeError: - h.handleMessageTypeError(data) - continue - case messageTypeConnectionKeepAlive: - continue - case messageTypePing, messageTypeNext: - h.log.Error("Invalid subprotocol. The subprotocol should be set to graphql-ws, but currently it is set to graphql-transport-ws") - return errors.New("invalid subprotocol") - default: - h.log.Error("unknown message type", abstractlogger.String("type", messageType)) - continue - } - } - } -} - -// readBlocking is a dedicated loop running in a separate goroutine -// because the library "github.com/coder/websocket" doesn't allow reading with a context with Timeout -// we'll block forever on reading until the context of the gqlWSConnectionHandler stops -func (h *gqlWSConnectionHandler) readBlocking(ctx context.Context, readTimeout time.Duration, dataCh chan []byte, errCh chan error) { - netOpErr := &net.OpError{} - for { - _ = h.conn.SetReadDeadline(time.Now().Add(readTimeout)) - data, err := wsutil.ReadServerText(h.conn) - if err != nil { - if errors.As(err, &netOpErr) { - if netOpErr.Timeout() { - select { - case <-ctx.Done(): - return - default: - continue - } - } - } - select { - case errCh <- err: - case <-ctx.Done(): - } - return - } - select { - case dataCh <- data: - case <-ctx.Done(): - return - } - } -} - -func (h *gqlWSConnectionHandler) unsubscribeAllAndCloseConn() { - h.unsubscribe() - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = ws.WriteFrame(h.conn, ws.MaskFrame(ws.NewCloseFrame(ws.NewCloseFrameBody(ws.StatusNormalClosure, "Normal Closure")))) - _ = h.conn.Close() -} - -func (h *gqlWSConnectionHandler) Ping() { - // This protocol has no client side ping/pong mechanism. The server send a ka message to understand - // if the connection is still alive. The client only acknowledges the retrieval of the ka message - // by consuming it in the readBlocking loop. - - // TODO We could check if we receive a ka message in a certain time frame and if not, we could close the connection - // However, because we don't send something to the server, we can't verify if the connection is still healthy and - // responsive from both sides. -} - -// subscribe adds a new Subscription to the gqlWSConnectionHandler and sends the startMessage to the origin -func (h *gqlWSConnectionHandler) subscribe() error { - graphQLBody, err := json.Marshal(h.options.Body) - if err != nil { - return err - } - startRequest := fmt.Sprintf(startMessage, "1", string(graphQLBody)) - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - err = wsutil.WriteClientText(h.conn, []byte(startRequest)) - if err != nil { - return err - } - return nil -} - -func (h *gqlWSConnectionHandler) handleMessageTypeData(data []byte) { - id, err := jsonparser.GetString(data, "id") - if err != nil { - return - } - if id != "1" { - return - } - payload, _, _, err := jsonparser.Get(data, "payload") - if err != nil { - return - } - - h.updater.Update(payload) -} - -func (h *gqlWSConnectionHandler) handleMessageTypeConnectionError() { - h.updater.Update([]byte(connectionError)) -} - -func (h *gqlWSConnectionHandler) broadcastErrorMessage(err error) { - errMsg := fmt.Sprintf(errorMessageTemplate, err) - h.updater.Update([]byte(errMsg)) -} - -func (h *gqlWSConnectionHandler) handleMessageTypeComplete(data []byte) { - id, err := jsonparser.GetString(data, "id") - if err != nil { - return - } - if id != "1" { - return - } - h.updater.Complete() -} - -func (h *gqlWSConnectionHandler) handleMessageTypeError(data []byte) { - id, err := jsonparser.GetString(data, "id") - if err != nil { - return - } - if id != "1" { - return - } - value, valueType, _, err := jsonparser.Get(data, "payload") - if err != nil { - h.updater.Update([]byte(internalError)) - return - } - switch valueType { - case jsonparser.Array: - response := []byte(`{}`) - response, err = jsonparser.Set(response, value, "errors") - if err != nil { - h.updater.Update([]byte(internalError)) - return - } - h.updater.Update(response) - case jsonparser.Object: - response := []byte(`{"errors":[]}`) - response, err = jsonparser.Set(response, value, "errors", "[0]") - if err != nil { - h.updater.Update([]byte(internalError)) - return - } - h.updater.Update(response) - default: - h.updater.Update([]byte(internalError)) - } -} - -func (h *gqlWSConnectionHandler) unsubscribe() { - h.updater.Complete() - stopRequest := fmt.Sprintf(stopMessage, "1") - _ = h.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) - _ = wsutil.WriteClientText(h.conn, []byte(stopRequest)) -} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go deleted file mode 100644 index eddc47253c..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_handler_test.go +++ /dev/null @@ -1,371 +0,0 @@ -//go:build !race - -package graphql_datasource - -import ( - "context" - "net" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/coder/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "github.com/wundergraph/graphql-go-tools/v2/pkg/testing/flags" -) - -func TestWebSocketSubscriptionClientInitIncludeKA_GQLWS(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - assertion := require.New(t) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assertion.NoError(err) - - // write "ka" every second - go func() { - for { - err := conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"ka"}`)) - if err != nil { - break - } - time.Sleep(time.Second) - } - }() - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assertion.NoError(err) - assertion.Equal(websocket.MessageText, msgType) - assertion.Equal(`{"type":"connection_init"}`, string(data)) - - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assertion.NoError(err) - - msgType, data, err = conn.Read(ctx) - assertion.NoError(err) - assertion.Equal(websocket.MessageText, msgType) - assertion.Equal(`{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assertion.NoError(err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assertion.NoError(err) - assertion.NoError(err) - - msgType, data, err = conn.Read(ctx) - assertion.NoError(err) - assertion.Equal(websocket.MessageText, msgType) - assertion.Equal(`{"type":"stop","id":"1"}`, string(data)) - close(serverDone) - })) - - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - defer clientCancel() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assertion.NoError(err) - }() - updater.AwaitUpdates(t, time.Second, 2) - assertion.Equal(`{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assertion.Equal(`{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - clientCancel() - assertion.Eventuallyf(func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestWebsocketSubscriptionClient_GQLWS(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"second"}}}}`)) - assert.NoError(t, err) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"third"}}}}`)) - assert.NoError(t, err) - - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - close(serverDone) - })) - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - rCtx := resolve.NewContext(ctx) - err := client.Subscribe(rCtx, GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second*5, 3) - assert.Equal(t, 3, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"second"}}}`, updater.updates[1]) - assert.Equal(t, `{"data":{"messageAdded":{"text":"third"}}}`, updater.updates[2]) - clientCancel() - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") - serverCancel() -} - -func TestWebsocketSubscriptionClientErrorArray(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - msgType, data, err := conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomNam: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"error","id":"1","payload":[{"message":"error"},{"message":"error"}]}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - _, _, err = conn.Read(r.Context()) - assert.NotNil(t, err) - close(serverDone) - })) - defer server.Close() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - clientCtx, clientCancel := context.WithCancel(context.Background()) - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(clientCtx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomNam: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 1) - assert.Equal(t, `{"errors":[{"message":"error"},{"message":"error"}]}`, updater.updates[0]) - clientCancel() - updater.AwaitDone(t, time.Second) - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") -} - -func TestWebsocketSubscriptionClientErrorObject(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverDone := make(chan struct{}) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - msgType, data, err := conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomNam: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"error","id":"1","payload":{"message":"error"}}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(r.Context()) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"stop","id":"1"}`, string(data)) - _, _, err = conn.Read(r.Context()) - assert.NotNil(t, err) - close(serverDone) - })) - defer server.Close() - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - clientCtx, clientCancel := context.WithCancel(context.Background()) - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Millisecond), - WithLogger(logger()), - ) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(clientCtx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomNam: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"errors":[{"message":"error"}]}`, updater.updates[0]) - clientCancel() - updater.AwaitDone(t, time.Second) - assert.Eventuallyf(t, func() bool { - <-serverDone - return true - }, time.Second, time.Millisecond*10, "server did not close") -} - -func TestWebsocketSubscriptionClient_GQLWS_Upstream_Dies(t *testing.T) { - if flags.IsWindows { - t.Skip("skipping test on windows") - } - - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.Accept(w, r, nil) - assert.NoError(t, err) - ctx := context.Background() - msgType, data, err := conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"connection_init"}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"connection_ack"}`)) - assert.NoError(t, err) - msgType, data, err = conn.Read(ctx) - assert.NoError(t, err) - assert.Equal(t, websocket.MessageText, msgType) - assert.Equal(t, `{"type":"start","id":"1","payload":{"query":"subscription {messageAdded(roomName: \"room\"){text}}"}}`, string(data)) - err = conn.Write(r.Context(), websocket.MessageText, []byte(`{"type":"data","id":"1","payload":{"data":{"messageAdded":{"text":"first"}}}}`)) - assert.NoError(t, err) - - <-serverCtx.Done() - })) - - // Wrap the listener to hijack the underlying TCP connection. - // Hijacking via http.ResponseWriter doesn't work because the WebSocket - // client already hijacks the connection before us. - wrappedListener := &listenerWrapper{ - listener: server.Listener, - } - server.Listener = wrappedListener - server.Start() - - defer server.Close() - ctx, clientCancel := context.WithCancel(context.Background()) - - // Start a new GQL subscription and exchange some messages. - client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx, - WithReadTimeout(time.Second), - WithLogger(logger()), - ).(*subscriptionClient) - updater := &testSubscriptionUpdater{} - go func() { - err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{ - URL: server.URL, - Body: GraphQLBody{ - Query: `subscription {messageAdded(roomName: "room"){text}}`, - }, - }, updater) - assert.NoError(t, err) - }() - updater.AwaitUpdates(t, time.Second, 1) - assert.Equal(t, 1, len(updater.updates)) - assert.Equal(t, `{"data":{"messageAdded":{"text":"first"}}}`, updater.updates[0]) - - // Kill the upstream here. We should get an End-of-File error. - assert.NoError(t, wrappedListener.underlyingConnection.Close()) - updater.AwaitUpdates(t, time.Second, 2) - assert.Equal(t, `{"errors":[{"message":"EOF"}]}`, updater.updates[1]) - - serverCancel() - clientCancel() -} - -type listenerWrapper struct { - listener net.Listener - underlyingConnection net.Conn -} - -func (l *listenerWrapper) Accept() (net.Conn, error) { - conn, err := l.listener.Accept() - if err != nil { - return nil, err - } - l.underlyingConnection = conn - return l.underlyingConnection, nil -} - -func (l *listenerWrapper) Close() error { - return l.listener.Close() -} - -func (l *listenerWrapper) Addr() net.Addr { - return l.listener.Addr() -} - -var _ net.Listener = (*listenerWrapper)(nil) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_proto_types.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_proto_types.go deleted file mode 100644 index 16b7b98d04..0000000000 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_ws_proto_types.go +++ /dev/null @@ -1,47 +0,0 @@ -package graphql_datasource - -// common -var ( - connectionInitMessage = []byte(`{"type":"connection_init"}`) -) - -const ( - messageTypeConnectionAck = "connection_ack" - messageTypeComplete = "complete" - messageTypeError = "error" -) - -// websocket sub-protocol: -// https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md -const ( - ProtocolGraphQLWS = "graphql-ws" - - startMessage = `{"type":"start","id":"%s","payload":%s}` - stopMessage = `{"type":"stop","id":"%s"}` - - messageTypeConnectionKeepAlive = "ka" - messageTypeData = "data" - messageTypeConnectionError = "connection_error" -) - -// websocket sub-protocol: -// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md -const ( - ProtocolGraphQLTWS = "graphql-transport-ws" - - subscribeMessage = `{"id":"%s","type":"subscribe","payload":%s}` - pongMessage = `{"type":"pong"}` - pingMessage = `{"type":"ping"}` - completeMessage = `{"id":"%s","type":"complete"}` - - messageTypePing = "ping" - messageTypePong = "pong" - messageTypeNext = "next" -) - -// internal -const ( - internalError = `{"errors":[{"message":"internal error"}]}` - connectionError = `{"errors":[{"message":"connection error"}]}` - errorMessageTemplate = `{"errors":[{"message":"%s"}]}` -) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go new file mode 100644 index 0000000000..6b844621ee --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client.go @@ -0,0 +1,118 @@ +package client + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport" +) + +// ErrClientClosed is returned when Subscribe is called after the client's context has been canceled. +var ErrClientClosed = errors.New("client closed") + +const ( + defaultReadLimit = 1 << 20 // 1MiB + defaultAckTimeout = 30 * time.Second + defaultWriteTimeout = 5 * time.Second +) + +type Client struct { + ctx context.Context + log abstractlogger.Logger + + ws *transport.WSTransport + sse *transport.SSETransport +} + +// Stats contains client statistics. +type Stats struct { + WSConns int // active WebSocket connections + SSEConns int // active SSE connections +} + +// Config holds the client configuration. +type Config struct { + UpgradeClient *http.Client + StreamingClient *http.Client + Logger abstractlogger.Logger + PingInterval time.Duration + PingTimeout time.Duration + AckTimeout time.Duration + WriteTimeout time.Duration + ReadLimit int64 + WSIdleTimeout time.Duration +} + +// New creates a new subscription client with the provided config. +func New(ctx context.Context, cfg Config) *Client { + if cfg.UpgradeClient == nil { + cfg.UpgradeClient = http.DefaultClient + } + if cfg.StreamingClient == nil { + cfg.StreamingClient = http.DefaultClient + } + if cfg.Logger == nil { + cfg.Logger = abstractlogger.NoopLogger + } + if cfg.ReadLimit <= 0 { + cfg.ReadLimit = defaultReadLimit + } + if cfg.AckTimeout <= 0 { + cfg.AckTimeout = defaultAckTimeout + } + if cfg.WriteTimeout <= 0 { + cfg.WriteTimeout = defaultWriteTimeout + } + + c := &Client{ + ctx: ctx, + log: cfg.Logger, + + ws: transport.NewWSTransport(ctx, transport.WSTransportOptions{ + UpgradeClient: cfg.UpgradeClient, + Logger: cfg.Logger, + PingInterval: cfg.PingInterval, + PingTimeout: cfg.PingTimeout, + AckTimeout: cfg.AckTimeout, + WriteTimeout: cfg.WriteTimeout, + ReadLimit: cfg.ReadLimit, + IdleTimeout: cfg.WSIdleTimeout, + }), + sse: transport.NewSSETransport(ctx, cfg.StreamingClient, cfg.Logger), + } + + c.log.Debug("subscriptionClient.New", abstractlogger.String("status", "initialized")) + + return c +} + +// Subscribe creates a new upstream via the appropriate transport. +func (c *Client) Subscribe(ctx context.Context, req *common.Request, opts common.Options, handler common.Handler) (func(), error) { + if c.ctx.Err() != nil { + return nil, ErrClientClosed + } + + switch opts.Transport { + case common.TransportSSE: + return c.sse.Subscribe(ctx, req, opts, handler) + case common.TransportWS: + return c.ws.Subscribe(ctx, req, opts, handler) + default: + return nil, fmt.Errorf("unsupported transport: %q", opts.Transport) + } +} + +// Stats returns client statistics. +func (c *Client) Stats() Stats { + stats := Stats{ + WSConns: c.ws.ConnCount(), + SSEConns: c.sse.ConnCount(), + } + return stats +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client_test.go new file mode 100644 index 0000000000..be71af528b --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/client_test.go @@ -0,0 +1,298 @@ +package client + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +func TestClient(t *testing.T) { + t.Run("new creates client with transports", func(t *testing.T) { + c := New(t.Context(), Config{}) + + assert.NotNil(t, c.ws) + assert.NotNil(t, c.sse) + }) + + t.Run("context cancellation is idempotent", func(t *testing.T) { + assert.NotPanics(t, func() { + ctx, cancel := context.WithCancel(context.Background()) + _ = New(ctx, Config{}) + cancel() + cancel() // should not panic + }) + }) + + t.Run("subscribe fails after context cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := New(ctx, Config{}) + cancel() + + _, err := c.Subscribe(t.Context(), &Request{Query: "subscription { a }"}, Options{ + Endpoint: "ws://localhost/graphql", + }, func(_ *common.Message) {}) + + assert.Equal(t, ErrClientClosed, err) + }) +} + +func TestClient_SubscriberDrain(t *testing.T) { + // These tests verify that cancelling all subscriptions properly cleans up all goroutines. + // Connections close themselves when their last subscriber is removed. + + t.Run("subscriber drain cleans up", func(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreAnyFunction("net/http/httptest.(*Server).goServe.func1")) + + server := newTestWSServer(t) + + c := New(t.Context(), Config{}) + + ch := make(chan *common.Message, 1) + subCancel, err := c.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }, func(msg *common.Message) { + ch <- msg + }) + require.NoError(t, err) + + // subscription is working + select { + case msg := <-ch: + require.NotNil(t, msg.Payload) + case <-time.After(time.Second): + t.Fatal("timeout waiting for message") + } + + subCancel() + + // Give ReadLoop goroutine time to exit after connection close + assert.Eventually(t, func() bool { + return c.Stats().WSConns == 0 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("subscriber drain cleans up multiple connections", func(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreAnyFunction("net/http/httptest.(*Server).goServe.func1")) + + server := newTestWSServer(t) + + c := New(t.Context(), Config{}) + + cancels := make([]func(), 3) + // Start subscriptions with different headers (forces multiple connections) + for i := range 3 { + headers := http.Header{"X-Request-ID": []string{string(rune('A' + i))}} + ch := make(chan *common.Message, 1) + subCancel, err := c.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + Headers: headers, + }, func(msg *common.Message) { + ch <- msg + }) + require.NoError(t, err) + cancels[i] = subCancel + + // Drain first message + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout") + } + } + + // Should have 3 connections + stats := c.Stats() + require.Equal(t, 3, stats.WSConns) + + for _, fn := range cancels { + fn() + } + + assert.Eventually(t, func() bool { + return c.Stats().WSConns == 0 + }, time.Second, 10*time.Millisecond) + }) +} + +func TestClient_CancelSendsComplete(t *testing.T) { + t.Run("cancel sends complete to server", func(t *testing.T) { + defer goleak.VerifyNone(t, + goleak.IgnoreAnyFunction("net/http/httptest.(*Server).goServe.func1"), + ) + + completeReceived := make(chan string, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer func() { + _ = conn.CloseNow() + }() + + ctx := r.Context() + + // Handle connection_init + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + // Read subscribe + var subMsg map[string]any + if err := wsjson.Read(ctx, conn, &subMsg); err != nil { + return + } + if subMsg["type"] != "subscribe" { + return + } + subID := subMsg["id"].(string) + + // Send a next message + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": subID, + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + + // Wait for complete message from client + var completeMsg map[string]any + if err := wsjson.Read(ctx, conn, &completeMsg); err != nil { + return + } + if completeMsg["type"] == "complete" { + completeReceived <- completeMsg["id"].(string) + } + })) + t.Cleanup(server.Close) + + c := New(t.Context(), Config{}) + + ch := make(chan *common.Message, 1) + cancel, err := c.Subscribe(t.Context(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }, func(msg *common.Message) { + ch <- msg + }) + require.NoError(t, err) + + // Wait for first message + select { + case msg := <-ch: + require.NotNil(t, msg.Payload) + case <-time.After(time.Second): + t.Fatal("timeout waiting for message") + } + + // Cancel the subscription - this should send complete to server + cancel() + + // Verify server received complete + select { + case id := <-completeReceived: + assert.NotEmpty(t, id, "complete message should have subscription ID") + case <-time.After(time.Second): + t.Fatal("timeout waiting for complete message on server") + } + + // Wait for ReadLoop goroutine to exit after connection close + assert.Eventually(t, func() bool { + return c.Stats().WSConns == 0 + }, time.Second, 10*time.Millisecond) + }) +} + +// Test helper: creates a WebSocket server that sends periodic messages +func newTestWSServer(t *testing.T) *httptest.Server { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer func() { + _ = conn.CloseNow() + }() + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + // Handle connection_init + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + // Read messages in background, cancel context when connection closes + go func() { + defer cancel() + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + // Handle subscribe by sending first message + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + }() + + // Send periodic messages until context cancelled + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + // Ignore write errors (connection may be closed) + _ = wsjson.Write(ctx, conn, map[string]any{ + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + })) + + t.Cleanup(server.Close) + return server +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go new file mode 100644 index 0000000000..d99f12a470 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/message.go @@ -0,0 +1,50 @@ +package common + +import ( + "encoding/json" + "errors" +) + +var ErrConnectionClosed = errors.New("connection closed") + +// MessageType identifies the kind of message delivered on a subscription channel. +type MessageType uint8 + +const ( + MessageTypeUnknown MessageType = iota + MessageTypeData // normal data payload + MessageTypeError // GraphQL-level error from server (has Payload) + MessageTypeComplete // subscription completed normally + MessageTypeConnectionError // connection-level error (has Err) +) + +// IsTerminal reports whether the message type signals end-of-stream. +func (t MessageType) IsTerminal() bool { + return t == MessageTypeError || t == MessageTypeComplete || t == MessageTypeConnectionError +} + +// Message is a single subscription event delivered to a Handler. +type Message struct { + Type MessageType + Payload *ExecutionResult + Err error // only set when Type == MessageTypeConnectionError +} + +// Handler receives subscription messages. It is called synchronously on the +// transport's read goroutine; a slow handler blocks message delivery. +type Handler func(msg *Message) + +// ExecutionResult is the GraphQL response payload for data and error messages. +type ExecutionResult struct { + Data json.RawMessage `json:"data,omitempty"` + Errors json.RawMessage `json:"errors,omitempty"` + Extensions json.RawMessage `json:"extensions,omitempty"` +} + +// Request is a GraphQL operation sent to the server when subscribing. +type Request struct { + Query string `json:"query"` + OperationName string `json:"operationName,omitempty"` + Variables json.RawMessage `json:"variables,omitempty"` + Extensions json.RawMessage `json:"extensions,omitempty"` +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go new file mode 100644 index 0000000000..fe0eb31c23 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common/options.go @@ -0,0 +1,57 @@ +package common + +import ( + "net/http" +) + +// TransportType selects the subscription transport mechanism. +type TransportType string + +const ( + TransportWS TransportType = "ws" // WebSocket connection + TransportSSE TransportType = "sse" // Server-Sent Events over HTTP +) + +// WSSubprotocol selects the GraphQL-over-WebSocket subprotocol. +type WSSubprotocol string + +const ( + SubprotocolAuto WSSubprotocol = "" // Auto, negotiated with the server + SubprotocolGraphQLTransportWS WSSubprotocol = "graphql-transport-ws" // Modern protocol from The Guild + SubprotocolGraphQLWS WSSubprotocol = "graphql-ws" // Legacy Apollo protocol, deprecated +) + +// Subprotocols returns the WebSocket subprotocol strings to offer during the upgrade handshake. +func (s WSSubprotocol) Subprotocols() []string { + switch s { + case SubprotocolAuto: + return []string{"graphql-transport-ws", "graphql-ws"} + case SubprotocolGraphQLTransportWS: + return []string{"graphql-transport-ws"} + case SubprotocolGraphQLWS: + return []string{"graphql-ws"} + default: + return nil + } +} + +type SSEMethod string + +const ( + SSEMethodPOST SSEMethod = "POST" // POST with JSON body (graphql-sse spec) + SSEMethodGET SSEMethod = "GET" // GET with query parameters (traditional SSE) +) + +// Options configures a single subscription request (endpoint, headers, transport selection). +type Options struct { + Endpoint string + Headers http.Header + InitPayload map[string]any + Transport TransportType + + // Only affects the WebSocket transport. + WSSubprotocol WSSubprotocol + + // Only affects the SSE transport. + SSEMethod SSEMethod +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go new file mode 100644 index 0000000000..f1ff7e47ab --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/exports.go @@ -0,0 +1,61 @@ +package client + +import ( + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport" +) + +// Re-exports from internal sub-packages (common, protocol, transport) so that +// callers can import this package alone instead of reaching into internals. + +type ( + Message = common.Message + MessageType = common.MessageType + ExecutionResult = common.ExecutionResult + Request = common.Request + Options = common.Options + Handler = common.Handler + TransportType = common.TransportType + WSSubprotocol = common.WSSubprotocol + SSEMethod = common.SSEMethod +) + +// Re-export constants. + +const ( + MessageTypeUnknown = common.MessageTypeUnknown + MessageTypeData = common.MessageTypeData + MessageTypeError = common.MessageTypeError + MessageTypeComplete = common.MessageTypeComplete + MessageTypeConnectionError = common.MessageTypeConnectionError + + TransportWS = common.TransportWS + TransportSSE = common.TransportSSE + + SubprotocolAuto = common.SubprotocolAuto + SubprotocolGraphQLTransportWS = common.SubprotocolGraphQLTransportWS + SubprotocolGraphQLWS = common.SubprotocolGraphQLWS + + SSEMethodPOST = common.SSEMethodPOST + SSEMethodGET = common.SSEMethodGET +) + +// Re-export error types. + +type ( + ErrFailedUpgrade = transport.ErrFailedUpgrade + ErrInvalidSubprotocol = transport.ErrInvalidSubprotocol +) + +// Re-export sentinel errors. + +var ( + ErrConnectionClosed = common.ErrConnectionClosed + ErrConnectionError = protocol.ErrConnectionError + ErrAckTimeout = protocol.ErrAckTimeout + ErrAckNotReceived = protocol.ErrAckNotReceived + ErrSubscriptionExists = transport.ErrSubscriptionExists + ErrDialFailed = transport.ErrDialFailed + ErrInitFailed = transport.ErrInitFailed +) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go new file mode 100644 index 0000000000..330dd969a3 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws.go @@ -0,0 +1,168 @@ +package protocol + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +// graphqlTransportWS implements the graphql-transport-ws protocol. +// See: https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md +type graphqlTransportWS struct{} + +const ( + gtwsTypeConnectionInit = "connection_init" + gtwsTypeConnectionAck = "connection_ack" + gtwsTypePing = "ping" + gtwsTypePong = "pong" + gtwsTypeSubscribe = "subscribe" + gtwsTypeNext = "next" + gtwsTypeError = "error" + gtwsTypeComplete = "complete" +) + +type outgoingMessage struct { + ID string `json:"id,omitempty"` + Type string `json:"type"` + Payload any `json:"payload,omitempty"` +} + +type incomingMessage struct { + ID string `json:"id,omitempty"` + Type string `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` +} + +func NewGraphQLTransportWS() *graphqlTransportWS { + return &graphqlTransportWS{} +} + +// Init implements Protocol. +func (p *graphqlTransportWS) Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error { + initMsg := outgoingMessage{ + Type: gtwsTypeConnectionInit, + } + if payload != nil { + initMsg.Payload = payload + } + if err := wsjson.Write(ctx, conn, initMsg); err != nil { + return fmt.Errorf("write connection_init: %w", err) + } + + for { + var ackMessage incomingMessage + if err := wsjson.Read(ctx, conn, &ackMessage); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return ErrAckTimeout + } + return fmt.Errorf("read connection_ack: %w", err) + } + + switch ackMessage.Type { + case gtwsTypeConnectionAck: + return nil + case gtwsTypePing: + if err := p.Pong(ctx, conn); err != nil { + return fmt.Errorf("pre-init pong: %w", err) + } + continue + case gtwsTypePong: + continue + default: + return fmt.Errorf("%w: got %q", ErrAckNotReceived, ackMessage.Type) + } + } +} + +// Ping implements Pinger. +func (p *graphqlTransportWS) Ping(ctx context.Context, conn *websocket.Conn) error { + msg := outgoingMessage{ + Type: gtwsTypePing, + } + return wsjson.Write(ctx, conn, msg) +} + +// Pong implements Pinger. +func (p *graphqlTransportWS) Pong(ctx context.Context, conn *websocket.Conn) error { + msg := outgoingMessage{ + Type: gtwsTypePong, + } + return wsjson.Write(ctx, conn, msg) +} + +// Read implements Protocol. +func (p *graphqlTransportWS) Read(ctx context.Context, conn *websocket.Conn) (*WireMessage, error) { + var raw incomingMessage + if err := wsjson.Read(ctx, conn, &raw); err != nil { + return nil, fmt.Errorf("read message: %w", err) + } + + return p.decode(raw) +} + +// Subscribe implements Protocol. +func (p *graphqlTransportWS) Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error { + msg := outgoingMessage{ + ID: id, + Type: gtwsTypeSubscribe, + Payload: req, + } + return wsjson.Write(ctx, conn, msg) +} + +// Unsubscribe implements Protocol. +func (p *graphqlTransportWS) Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error { + msg := outgoingMessage{ + ID: id, + Type: gtwsTypeComplete, + } + return wsjson.Write(ctx, conn, msg) +} + +func (p *graphqlTransportWS) decode(raw incomingMessage) (*WireMessage, error) { + msg := &WireMessage{ + ID: raw.ID, + } + + switch raw.Type { + case gtwsTypeNext: + msg.Type = MessageData + if raw.Payload != nil { + var resp common.ExecutionResult + if err := json.Unmarshal(raw.Payload, &resp); err != nil { + return nil, fmt.Errorf("unmarshal next payload: %w", err) + } + msg.Payload = &resp + } + case gtwsTypeError: + msg.Type = MessageError + if raw.Payload != nil { + msg.Payload = &common.ExecutionResult{Errors: raw.Payload} + } + + case gtwsTypeComplete: + msg.Type = MessageComplete + + case gtwsTypePing: + msg.Type = MessagePing + + case gtwsTypePong: + msg.Type = MessagePong + + default: + return nil, fmt.Errorf("unknown message type: %s", raw.Type) + } + + return msg, nil +} + +var ( + _ Protocol = (*graphqlTransportWS)(nil) + _ Pinger = (*graphqlTransportWS)(nil) +) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws_test.go new file mode 100644 index 0000000000..951109c4ff --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_transport_ws_test.go @@ -0,0 +1,380 @@ +package protocol + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +func TestGraphQLTransportWS_Init(t *testing.T) { + t.Parallel() + + t.Run("sends connection_init and receives connection_ack", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + + err := p.Init(t.Context(), conn, map[string]any{"secret": "token"}) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "connection_init", msg["type"]) + payload, _ := msg["payload"].(map[string]any) + assert.Equal(t, "token", payload["secret"]) + }) + }) + + t.Run("returns error when ack times out", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + time.Sleep(500 * time.Millisecond) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancel() + err := p.Init(ctx, conn, nil) + + require.ErrorIs(t, err, ErrAckTimeout) + }) + + t.Run("handles ping before ack", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 2) + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg1 map[string]any + _ = wsjson.Read(ctx, conn, &msg1) + received <- msg1 + + _ = wsjson.Write(ctx, conn, map[string]string{"type": "ping"}) + + var msg2 map[string]any + _ = wsjson.Read(ctx, conn, &msg2) + received <- msg2 + + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + err := p.Init(t.Context(), conn, nil) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "connection_init", msg["type"]) + }) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "pong", msg["type"]) + }) + }) + + t.Run("returns error on unexpected message type", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + _ = wsjson.Write(ctx, conn, map[string]string{"type": "error"}) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + err := p.Init(t.Context(), conn, nil) + + assert.ErrorIs(t, err, ErrAckNotReceived) + }) +} + +func TestGraphQLWS_Subscribe(t *testing.T) { + t.Parallel() + + t.Run("sends subscribe message with query and variables", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + err := p.Subscribe(t.Context(), conn, "sub-1", &common.Request{ + Query: "subscription { test }", + Variables: []byte(`{"id": 123}`), + }) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "subscribe", msg["type"]) + assert.Equal(t, "sub-1", msg["id"]) + + payload, _ := msg["payload"].(map[string]any) + assert.Equal(t, "subscription { test }", payload["query"]) + + vars, _ := payload["variables"].(map[string]any) + assert.Equal(t, float64(123), vars["id"]) + }) + }) +} + +func TestGraphQLWS_Unsubscribe(t *testing.T) { + t.Parallel() + + t.Run("sends complete message with subscription id", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + err := p.Unsubscribe(t.Context(), conn, "sub-1") + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "complete", msg["type"]) + assert.Equal(t, "sub-1", msg["id"]) + }) + }) +} + +func TestGraphQLWS_Read(t *testing.T) { + t.Parallel() + + t.Run("decodes next message with data payload", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": "sub-1", + "type": "next", + "payload": map[string]any{ + "data": map[string]any{"value": 42}, + }, + }) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, "sub-1", msg.ID) + assert.Equal(t, MessageData, msg.Type) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "42") + }) + + t.Run("decodes error message with graphql errors", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": "sub-1", + "type": "error", + "payload": []map[string]any{ + {"message": "something went wrong"}, + }, + }) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, MessageError, msg.Type) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Errors), "something went wrong") + }) + + t.Run("decodes complete message", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": "sub-1", + "type": "complete", + }) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, "sub-1", msg.ID) + assert.Equal(t, MessageComplete, msg.Type) + }) + + t.Run("decodes ping message", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]string{"type": "ping"}) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, MessagePing, msg.Type) + }) + + t.Run("returns error for unknown message type", func(t *testing.T) { + t.Parallel() + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]string{"type": "unknown"}) + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + _, err := p.Read(t.Context(), conn) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown") + }) +} + +func TestGraphQLTransportWS_PingPong(t *testing.T) { + t.Parallel() + + t.Run("sends ping message", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + err := p.Ping(t.Context(), conn) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "ping", msg["type"]) + }) + }) + + t.Run("sends pong message", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + + server := newGTWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + }) + + conn := dialGTWS(t, server) + + p := NewGraphQLTransportWS() + err := p.Pong(t.Context(), conn) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "pong", msg["type"]) + }) + }) +} + +func newGTWSTestServer(t *testing.T, handler func(ctx context.Context, conn *websocket.Conn)) *httptest.Server { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + handler(ctx, conn) + })) + + t.Cleanup(server.Close) + + return server +} + +func dialGTWS(t *testing.T, server *httptest.Server) *websocket.Conn { + t.Helper() + + conn, _, err := websocket.Dial(t.Context(), server.URL, &websocket.DialOptions{ //nolint:bodyclose + Subprotocols: []string{"graphql-transport-ws"}, + }) + require.NoError(t, err) + + t.Cleanup(func() { + _ = conn.Close(websocket.StatusNormalClosure, "") + }) + + return conn +} + +func awaitChannelWithT[A any](t *testing.T, timeout time.Duration, ch <-chan A, f func(*testing.T, A), msgAndArgs ...any) { + t.Helper() + + select { + case args := <-ch: + f(t, args) + case <-time.After(timeout): + require.Fail(t, "unable to receive message before timeout", msgAndArgs...) + } +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go new file mode 100644 index 0000000000..ee9f24e38b --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws.go @@ -0,0 +1,148 @@ +package protocol + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +// graphqlWS implements the legacy graphql-ws protocol. +// See: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +type graphqlWS struct{} + +const ( + gwsTypeConnectionInit = "connection_init" + gwsTypeConnectionAck = "connection_ack" + gwsTypeConnectionError = "connection_error" + gwsTypeConnectionKeepAlive = "ka" + gwsTypeStart = "start" + gwsTypeData = "data" + gwsTypeError = "error" + gwsTypeComplete = "complete" + gwsTypeStop = "stop" +) + +func NewGraphQLWS() *graphqlWS { + return &graphqlWS{} +} + +// Init implements Protocol. +func (p *graphqlWS) Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error { + initMsg := outgoingMessage{ + Type: gwsTypeConnectionInit, + } + if payload != nil { + initMsg.Payload = payload + } + if err := wsjson.Write(ctx, conn, initMsg); err != nil { + return fmt.Errorf("write connection_init: %w", err) + } + + for { + var ackMessage incomingMessage + if err := wsjson.Read(ctx, conn, &ackMessage); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return ErrAckTimeout + } + return fmt.Errorf("read connection_ack: %w", err) + } + + switch ackMessage.Type { + case gwsTypeConnectionAck: + return nil + case gwsTypeConnectionKeepAlive: + // Keep-alive messages can arrive before ack, ignore them + continue + case gwsTypeConnectionError: + var errPayload map[string]any + if ackMessage.Payload != nil { + // If this fails, the error will have nil errors anyway, handling it does nothing unique + _ = json.Unmarshal(ackMessage.Payload, &errPayload) + } + return fmt.Errorf("%w: %v", ErrConnectionError, errPayload) + default: + return fmt.Errorf("%w: got %q", ErrAckNotReceived, ackMessage.Type) + } + } +} + +// Subscribe implements Protocol. +func (p *graphqlWS) Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error { + msg := outgoingMessage{ + ID: id, + Type: gwsTypeStart, + Payload: req, + } + return wsjson.Write(ctx, conn, msg) +} + +// Unsubscribe implements Protocol. +func (p *graphqlWS) Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error { + msg := outgoingMessage{ + ID: id, + Type: gwsTypeStop, + } + return wsjson.Write(ctx, conn, msg) +} + +// Read implements Protocol. +func (p *graphqlWS) Read(ctx context.Context, conn *websocket.Conn) (*WireMessage, error) { + var raw incomingMessage + if err := wsjson.Read(ctx, conn, &raw); err != nil { + return nil, fmt.Errorf("read message: %w", err) + } + + return p.decode(raw) +} + +func (p *graphqlWS) decode(raw incomingMessage) (*WireMessage, error) { + msg := &WireMessage{ + ID: raw.ID, + } + + switch raw.Type { + case gwsTypeData: + msg.Type = MessageData + if raw.Payload != nil { + var resp common.ExecutionResult + if err := json.Unmarshal(raw.Payload, &resp); err != nil { + return nil, fmt.Errorf("unmarshal data payload: %w", err) + } + msg.Payload = &resp + } + + case gwsTypeError: + msg.Type = MessageError + if raw.Payload != nil { + msg.Payload = &common.ExecutionResult{Errors: raw.Payload} + } + + case gwsTypeComplete: + msg.Type = MessageComplete + + case gwsTypeConnectionKeepAlive: + // Map keep-alive to ping for consistent handling + msg.Type = MessagePing + + case gwsTypeConnectionError: + msg.Type = MessageError + var errPayload map[string]any + if raw.Payload != nil { + _ = json.Unmarshal(raw.Payload, &errPayload) + } + msg.Err = fmt.Errorf("%w: %v", ErrConnectionError, errPayload) + + default: + return nil, fmt.Errorf("unknown message type: %s", raw.Type) + } + + return msg, nil +} + +var _ Protocol = (*graphqlWS)(nil) diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go new file mode 100644 index 0000000000..f19ebe095b --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/graphql_ws_test.go @@ -0,0 +1,351 @@ +package protocol + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +func TestGraphQLWS_Init(t *testing.T) { + t.Parallel() + + t.Run("sends connection_init and receives connection_ack", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + + err := p.Init(t.Context(), conn, map[string]any{"secret": "token"}) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "connection_init", msg["type"]) + payload, _ := msg["payload"].(map[string]any) + assert.Equal(t, "token", payload["secret"]) + }) + }) + + t.Run("returns error when ack times out", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + time.Sleep(500 * time.Millisecond) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + ctx, cancel := context.WithTimeout(t.Context(), 50*time.Millisecond) + defer cancel() + err := p.Init(ctx, conn, nil) + + require.ErrorIs(t, err, ErrAckTimeout) + }) + + t.Run("handles keep-alive before ack", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg1 map[string]any + _ = wsjson.Read(ctx, conn, &msg1) // connection_init + + // Send keep-alive before ack + _ = wsjson.Write(ctx, conn, map[string]string{"type": "ka"}) + + // Then send ack + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + err := p.Init(t.Context(), conn, nil) + require.NoError(t, err) + }) + + t.Run("returns error on connection_error", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + _ = wsjson.Write(ctx, conn, map[string]any{ + "type": "connection_error", + "payload": map[string]any{"message": "auth failed"}, + }) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + err := p.Init(t.Context(), conn, nil) + + require.ErrorIs(t, err, ErrConnectionError) + assert.Contains(t, err.Error(), "auth failed") + }) + + t.Run("returns error on unexpected message type", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + _ = wsjson.Write(ctx, conn, map[string]string{"type": "error"}) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + err := p.Init(t.Context(), conn, nil) + + assert.ErrorIs(t, err, ErrAckNotReceived) + }) +} + +func TestGraphQLWSLegacy_Subscribe(t *testing.T) { + t.Parallel() + + t.Run("sends start message with query and variables", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + err := p.Subscribe(t.Context(), conn, "sub-1", &common.Request{ + Query: "subscription { test }", + Variables: []byte(`{"id": 123}`), + }) + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "start", msg["type"]) + assert.Equal(t, "sub-1", msg["id"]) + + payload, _ := msg["payload"].(map[string]any) + assert.Equal(t, "subscription { test }", payload["query"]) + + vars, _ := payload["variables"].(map[string]any) + assert.Equal(t, float64(123), vars["id"]) + }) + }) +} + +func TestGraphQLWSLegacy_Unsubscribe(t *testing.T) { + t.Parallel() + + t.Run("sends stop message with subscription id", func(t *testing.T) { + t.Parallel() + + received := make(chan map[string]any, 1) + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err == nil { + received <- msg + } + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + err := p.Unsubscribe(t.Context(), conn, "sub-1") + require.NoError(t, err) + + awaitChannelWithT(t, time.Second, received, func(t *testing.T, msg map[string]any) { + assert.Equal(t, "stop", msg["type"]) + assert.Equal(t, "sub-1", msg["id"]) + }) + }) +} + +func TestGraphQLWSLegacy_Read(t *testing.T) { + t.Parallel() + + t.Run("decodes data message with payload", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": "sub-1", + "type": "data", + "payload": map[string]any{ + "data": map[string]any{"value": 42}, + }, + }) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, "sub-1", msg.ID) + assert.Equal(t, MessageData, msg.Type) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "42") + }) + + t.Run("decodes error message with graphql errors", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": "sub-1", + "type": "error", + "payload": []map[string]any{ + {"message": "something went wrong"}, + }, + }) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, MessageError, msg.Type) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Errors), "something went wrong") + }) + + t.Run("decodes complete message", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": "sub-1", + "type": "complete", + }) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, "sub-1", msg.ID) + assert.Equal(t, MessageComplete, msg.Type) + }) + + t.Run("decodes keep-alive message as ping", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]string{"type": "ka"}) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, MessagePing, msg.Type) + }) + + t.Run("decodes connection_error message", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "type": "connection_error", + "payload": map[string]any{"reason": "session expired"}, + }) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + msg, err := p.Read(t.Context(), conn) + + require.NoError(t, err) + assert.Equal(t, MessageError, msg.Type) + require.Error(t, msg.Err) + assert.Contains(t, msg.Err.Error(), "session expired") + }) + + t.Run("returns error for unknown message type", func(t *testing.T) { + t.Parallel() + + server := newGWSTestServer(t, func(ctx context.Context, conn *websocket.Conn) { + _ = wsjson.Write(ctx, conn, map[string]any{ + "type": "unknown", + }) + }) + + conn := dialGWS(t, server) + + p := NewGraphQLWS() + _, err := p.Read(t.Context(), conn) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown") + }) +} + +func newGWSTestServer(t *testing.T, handler func(ctx context.Context, conn *websocket.Conn)) *httptest.Server { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + handler(ctx, conn) + })) + + t.Cleanup(server.Close) + + return server +} + +func dialGWS(t *testing.T, server *httptest.Server) *websocket.Conn { + t.Helper() + + conn, _, err := websocket.Dial(t.Context(), server.URL, &websocket.DialOptions{ //nolint:bodyclose + Subprotocols: []string{"graphql-ws"}, + }) + require.NoError(t, err) + + t.Cleanup(func() { + _ = conn.Close(websocket.StatusNormalClosure, "") + }) + + return conn +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go new file mode 100644 index 0000000000..f4941c313a --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol/protocol.go @@ -0,0 +1,93 @@ +package protocol + +import ( + "context" + "errors" + + "github.com/coder/websocket" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +// Protocol defines the message framing and behaviour used on a WS connection. +type Protocol interface { + // Init performs the connection handshake with the server. + Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error + + // Subscribe starts a subscription for the given operation. + Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error + + // Unsubscribe ends a subscription. + Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error + + // Read blocks until the next message arrives and decodes it. + Read(ctx context.Context, conn *websocket.Conn) (*WireMessage, error) +} + +// Pinger is an optional interface for protocols that support client-initiated +// ping/pong (e.g. graphql-transport-ws). Protocols that only have server-initiated +// keep-alive (e.g. legacy graphql-ws with ka messages) do not implement this. +type Pinger interface { + Ping(ctx context.Context, conn *websocket.Conn) error + Pong(ctx context.Context, conn *websocket.Conn) error +} + +var ( + ErrAckTimeout = errors.New("connection_ack timeout") + ErrAckNotReceived = errors.New("expected connection_ack") + ErrConnectionError = errors.New("connection error from server") +) + +// WireMessage is a decoded wire-level protocol message. +// It is different from the common message format because it still contains the ID and internal type, +// which is not exposed to consumers. +type WireMessage struct { + ID string + Type WireMessageType + Payload *common.ExecutionResult + Err error +} + +func (m *WireMessage) IntoClientMessage() *common.Message { + switch m.Type { + case MessageData: + return &common.Message{Type: common.MessageTypeData, Payload: m.Payload} + case MessageError: + if m.Payload != nil { + return &common.Message{Type: common.MessageTypeError, Payload: m.Payload} + } + return &common.Message{Type: common.MessageTypeConnectionError, Err: m.Err} + case MessageComplete: + return &common.Message{Type: common.MessageTypeComplete} + default: + return &common.Message{Type: common.MessageTypeUnknown} + } +} + +// WireMessageType identifies the message type. +type WireMessageType uint8 + +const ( + MessageData WireMessageType = iota + MessageError + MessageComplete + MessagePing + MessagePong +) + +func (t WireMessageType) String() string { + switch t { + case MessageData: + return "data" + case MessageError: + return "error" + case MessageComplete: + return "complete" + case MessagePing: + return "ping" + case MessagePong: + return "pong" + default: + return "unknown" + } +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go new file mode 100644 index 0000000000..a98999bf5e --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn.go @@ -0,0 +1,175 @@ +package transport + +import ( + "bytes" + "encoding/json" + "net/http" + "sync/atomic" + + "github.com/r3labs/sse/v2" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +var ( + headerData = []byte("data:") + headerEvent = []byte("event:") +) + +// sseConnection handles a single SSE subscription stream. +type sseConnection struct { + resp *http.Response + handler common.Handler + closed atomic.Bool +} + +func newSSEConnection(resp *http.Response, handler common.Handler) *sseConnection { + return &sseConnection{ + resp: resp, + handler: handler, + } +} + +// readLoop reads SSE events from the response body and delivers them to the handler. +// Every exit path delivers a terminal message to the handler unless the connection +// was closed by the consumer. +func (c *sseConnection) readLoop() { + defer c.cleanup() + + reader := sse.NewEventStreamReader(c.resp.Body, 1<<16) // 64KB + + for { + if c.closed.Load() { + return + } + + eventBytes, err := reader.ReadEvent() + if err != nil { + if c.closed.Load() { + return + } + c.sendError(err) + return + } + + // Parse the raw event bytes into event type and data + eventType, data := c.parseEventBytes(eventBytes) + + // Skip empty events (e.g., keep-alive comments) + if eventType == "" && data == nil { + continue + } + + msg := c.parseEvent(eventType, data) + + if c.closed.Load() { + return + } + c.handler(msg) + + if msg.Type.IsTerminal() { + return + } + } +} + +// parseEventBytes extracts the event type and data from raw SSE event bytes. +// Based on r3labs/sse's processEvent but simplified for our needs. +func (c *sseConnection) parseEventBytes(msg []byte) (eventType string, data []byte) { + if len(msg) == 0 { + return "", nil + } + + for line := range bytes.Lines(msg) { + line = bytes.TrimRight(line, "\r\n") + switch { + case bytes.HasPrefix(line, headerEvent): + eventType = string(trimHeader(len(headerEvent), line)) + + case bytes.HasPrefix(line, headerData): + // The spec allows for multiple data fields per event, concatenated with "\n" + data = append(data, trimHeader(len(headerData), line)...) + data = append(data, '\n') + + case bytes.Equal(line, []byte("data")): + // A line that simply contains "data" should be treated as empty data + data = append(data, '\n') + + // Comments (lines starting with ':') are already filtered by EventStreamReader + } + } + + // Trim the trailing "\n" per SSE spec + data = bytes.TrimSuffix(data, []byte("\n")) + + return eventType, data +} + +// trimHeader removes the header prefix and optional leading space. +func trimHeader(size int, data []byte) []byte { + if len(data) < size { + return data + } + + data = data[size:] + // Remove optional leading whitespace (single space after colon) + if len(data) > 0 && data[0] == ' ' { + data = data[1:] + } + return data +} + +// parseEvent converts parsed SSE event data into a shared.Message. +func (c *sseConnection) parseEvent(eventType string, data []byte) *common.Message { + switch eventType { + case "next": + var resp common.ExecutionResult + if err := json.Unmarshal(data, &resp); err != nil { + return &common.Message{Type: common.MessageTypeConnectionError, Err: err} + } + return &common.Message{Type: common.MessageTypeData, Payload: &resp} + + case "error": + return &common.Message{ + Type: common.MessageTypeError, + Payload: &common.ExecutionResult{Errors: data}, + } + + case "complete": + return &common.Message{Type: common.MessageTypeComplete} + + default: + // Unknown event type or no event type specified - treat as data + // This handles servers that send data without an event type + if len(data) == 0 { + return &common.Message{Type: common.MessageTypeComplete} + } + var resp common.ExecutionResult + if err := json.Unmarshal(data, &resp); err != nil { + return &common.Message{Type: common.MessageTypeConnectionError, Err: err} + } + return &common.Message{Type: common.MessageTypeData, Payload: &resp} + } +} + +func (c *sseConnection) sendError(err error) { + if c.closed.Load() { + return + } + c.handler(&common.Message{Type: common.MessageTypeConnectionError, Err: err}) +} + +func (c *sseConnection) cleanup() { + c.closed.Store(true) + + c.resp.Body.Close() +} + +// closeConn terminates the SSE connection. +func (c *sseConnection) closeConn() { + if !c.closed.CompareAndSwap(false, true) { + return + } + + c.resp.Body.Close() +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go new file mode 100644 index 0000000000..8aaca5fdc5 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_conn_test.go @@ -0,0 +1,139 @@ +package transport + +import ( + "io" + "net/http" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +func TestSSEConnection_ReadLoop(t *testing.T) { + t.Run("reads and parses SSE events", func(t *testing.T) { + body := io.NopCloser(strings.NewReader( + "event: next\ndata: {\"data\":{\"time\":\"12:00\"}}\n\n", + )) + resp := &http.Response{Body: body} + handler, receive := collectingHandler() + conn := newSSEConnection(resp, handler) + + go conn.readLoop() + + msg := receive(t, 1*time.Second) + require.NotNil(t, msg.Payload) + // Data field contains the raw "data" value from GraphQL response + assert.JSONEq(t, `{"time":"12:00"}`, string(msg.Payload.Data)) + }) + + t.Run("delivers connection error on EOF", func(t *testing.T) { + body := io.NopCloser(strings.NewReader("")) + resp := &http.Response{Body: body} + handler, receive := collectingHandler() + conn := newSSEConnection(resp, handler) + + go conn.readLoop() + + msg := receive(t, 1*time.Second) + assert.Equal(t, common.MessageTypeConnectionError, msg.Type) + assert.Error(t, msg.Err) + }) + + t.Run("sends error on read failure", func(t *testing.T) { + body := &errorReader{err: io.ErrUnexpectedEOF} + resp := &http.Response{Body: io.NopCloser(body)} + handler, receive := collectingHandler() + conn := newSSEConnection(resp, handler) + + go conn.readLoop() + + msg := receive(t, 1*time.Second) + require.Error(t, msg.Err) + assert.Equal(t, common.MessageTypeConnectionError, msg.Type) + }) + + t.Run("stops on complete event", func(t *testing.T) { + body := io.NopCloser(strings.NewReader( + "event: next\ndata: {\"data\":{}}\n\n" + + "event: complete\ndata:\n\n" + + "event: next\ndata: {\"data\":{}}\n\n", // Should not receive this + )) + resp := &http.Response{Body: body} + handler, receive := collectingHandler() + wrappedHandler, collect := waitForMessages(handler) + conn := newSSEConnection(resp, wrappedHandler) + + go conn.readLoop() + + // First message + msg1 := receive(t, 1*time.Second) + assert.NotNil(t, msg1.Payload) + assert.Equal(t, common.MessageTypeData, msg1.Type) + + // Complete message + msg2 := receive(t, 1*time.Second) + assert.Equal(t, common.MessageTypeComplete, msg2.Type) + + // Wait and verify no more messages arrive after complete + messages := collect(100 * time.Millisecond) + assert.Len(t, messages, 2, "should receive exactly 2 messages before stopping") + }) +} + +func TestSSEConnection_Close(t *testing.T) { + t.Run("closes body", func(t *testing.T) { + pr, pw := io.Pipe() + body := &trackingCloser{Reader: pr} + resp := &http.Response{Body: body} + handler, _ := collectingHandler() + conn := newSSEConnection(resp, handler) + + go conn.readLoop() + + conn.closeConn() + pw.Close() // Ensure pipe is fully closed + + assert.Eventually(t, func() bool { + return body.closed.Load() + }, 1*time.Second, 10*time.Millisecond, "body should be closed") + }) + + t.Run("is idempotent", func(t *testing.T) { + body := io.NopCloser(strings.NewReader("")) + resp := &http.Response{Body: body} + handler, _ := collectingHandler() + conn := newSSEConnection(resp, handler) + + conn.closeConn() + conn.closeConn() // second call is a no-op + }) +} + +// errorReader always returns an error +type errorReader struct { + err error +} + +func (r *errorReader) Read(_ []byte) (int, error) { + return 0, r.err +} + +// trackingCloser tracks if Close was called and forwards to the underlying reader if it implements io.Closer. +type trackingCloser struct { + io.Reader + + closed atomic.Bool +} + +func (c *trackingCloser) Close() error { + c.closed.Store(true) + if closer, ok := c.Reader.(io.Closer); ok { + return closer.Close() + } + return nil +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go new file mode 100644 index 0000000000..4e93a3fa4d --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport.go @@ -0,0 +1,269 @@ +package transport + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "maps" + "mime" + "net/http" + "net/url" + "strings" + "sync" + + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +const maxErrorBodySize = 4096 + +// SSETransport implements the Transport interface using Server-Sent Events. +// Unlike WebSocket, each subscription creates a separate HTTP request. +// TCP connection reuse is handled by http.Client's connection pool. +// +// Supports both POST (graphql-sse spec) and GET (traditional SSE) methods. +type SSETransport struct { + ctx context.Context + client *http.Client + log abstractlogger.Logger + + mu sync.Mutex + conns map[*sseConnection]struct{} +} + +// NewSSETransport creates a new SSETransport with the provided http.Client. +// The transport will automatically close all connections when ctx is cancelled. +func NewSSETransport(ctx context.Context, client *http.Client, log abstractlogger.Logger) *SSETransport { + if log == nil { + log = abstractlogger.NoopLogger + } + + t := &SSETransport{ + ctx: ctx, + client: client, + log: log, + conns: make(map[*sseConnection]struct{}), + } + + context.AfterFunc(ctx, t.closeAll) + + return t +} + +// Subscribe initiates a GraphQL subscription over SSE. +// Each call creates a new HTTP request (no multiplexing). +// +// The HTTP method is determined by opts.SSEMethod: +// - SSEMethodPOST: POST with JSON body (graphql-sse spec) +// - SSEMethodGET: GET with query parameters (traditional SSE) +func (t *SSETransport) Subscribe(ctx context.Context, req *common.Request, opts common.Options, handler common.Handler) (func(), error) { + var httpReq *http.Request + var err error + + t.log.Debug("sseTransport.Subscribe", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.String("method", string(opts.SSEMethod)), + ) + + switch opts.SSEMethod { + case common.SSEMethodPOST: + httpReq, err = buildPOSTRequest(req, opts) + case common.SSEMethodGET: + httpReq, err = buildGETRequest(req, opts) + default: + return nil, fmt.Errorf("unsupported SSE method: %s", opts.SSEMethod) + } + + if err != nil { + return nil, err + } + + // Derive a request context that outlives ctx (via WithoutCancel) so we can + // control its lifetime independently. Two AfterFunc registrations tie the + // request to both shutdown paths: + // - t.ctx cancel: transport-wide shutdown, tears down all in-flight requests. + // - ctx cancel: individual subscription cancelled by the caller. + requestCtx, requestCancel := context.WithCancel(context.WithoutCancel(ctx)) + context.AfterFunc(t.ctx, requestCancel) + context.AfterFunc(ctx, requestCancel) + + httpReq = httpReq.WithContext(requestCtx) + + // Execute request + resp, err := t.client.Do(httpReq) + if err != nil { + requestCancel() + t.log.Error("sseTransport.Subscribe", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.Error(err), + ) + return nil, fmt.Errorf("execute request: %w", err) + } + + if resp.StatusCode != http.StatusOK { + requestCancel() + body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodySize)) + resp.Body.Close() + t.log.Error("sseTransport.Subscribe", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.Int("status", resp.StatusCode), + ) + if len(body) > 0 { + return nil, fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(body)) + } + return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode) + } + + // Verify content type (should be text/event-stream) + if err := t.validateContentType(resp); err != nil { + requestCancel() + resp.Body.Close() + return nil, err + } + + t.log.Debug("sseTransport.Subscribe", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.String("status", "connected"), + ) + + // Create connection + conn := newSSEConnection(resp, handler) + + t.mu.Lock() + t.conns[conn] = struct{}{} + t.mu.Unlock() + + go conn.readLoop() + + cancelFn := func() { + requestCancel() + conn.closeConn() + t.removeConn(conn) + } + + return cancelFn, nil +} + +// buildPOSTRequest creates a POST request with JSON body (graphql-sse spec). +func buildPOSTRequest(req *common.Request, opts common.Options) (*http.Request, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + httpReq, err := http.NewRequest(http.MethodPost, opts.Endpoint, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + // Add custom headers first, then set SSE-required headers so they cannot be overwritten + maps.Copy(httpReq.Header, opts.Headers) + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Cache-Control", "no-cache") + + return httpReq, nil +} + +// buildGETRequest creates a GET request with query parameters (traditional SSE). +func buildGETRequest(req *common.Request, opts common.Options) (*http.Request, error) { + // Parse the endpoint URL + u, err := url.Parse(opts.Endpoint) + if err != nil { + return nil, fmt.Errorf("parse endpoint: %w", err) + } + + // Build query parameters + q := u.Query() + q.Set("query", req.Query) + + if len(req.Variables) > 0 { + varsJSON, err := json.Marshal(req.Variables) + if err != nil { + return nil, fmt.Errorf("marshal variables: %w", err) + } + q.Set("variables", string(varsJSON)) + } + + if req.OperationName != "" { + q.Set("operationName", req.OperationName) + } + + if len(req.Extensions) > 0 { + extJSON, err := json.Marshal(req.Extensions) + if err != nil { + return nil, fmt.Errorf("marshal extensions: %w", err) + } + q.Set("extensions", string(extJSON)) + } + + u.RawQuery = q.Encode() + + httpReq, err := http.NewRequest(http.MethodGet, u.String(), nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + // Add custom headers first, then set SSE-required headers so they cannot be overwritten + maps.Copy(httpReq.Header, opts.Headers) + + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Cache-Control", "no-cache") + + return httpReq, nil +} + +// validateContentType checks that the response has the correct content type. +func (t *SSETransport) validateContentType(resp *http.Response) error { + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + return nil // Allow missing content-type + } + + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return fmt.Errorf("invalid content-type %q: %w", contentType, err) + } + + if !strings.EqualFold(mediaType, "text/event-stream") { + return fmt.Errorf("unexpected content-type: %s", contentType) + } + + return nil +} + +func (t *SSETransport) removeConn(conn *sseConnection) { + t.mu.Lock() + delete(t.conns, conn) + t.mu.Unlock() +} + +// closeAll terminates all active SSE connections. Called automatically when context is cancelled. +func (t *SSETransport) closeAll() { + t.mu.Lock() + conns := make([]*sseConnection, 0, len(t.conns)) + for conn := range t.conns { + conns = append(conns, conn) + } + t.conns = make(map[*sseConnection]struct{}) + t.mu.Unlock() + + t.log.Debug("sseTransport.closeAll", + abstractlogger.Int("connections", len(conns)), + ) + + for _, conn := range conns { + conn.closeConn() + } +} + +// ConnCount returns the number of active SSE connections. +func (t *SSETransport) ConnCount() int { + t.mu.Lock() + defer t.mu.Unlock() + return len(t.conns) +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go new file mode 100644 index 0000000000..19100241ae --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/sse_transport_test.go @@ -0,0 +1,857 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "io" + "maps" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +func TestSSETransport_Subscribe(t *testing.T) { + t.Parallel() + + t.Run("sends POST request and receives messages", func(t *testing.T) { + t.Parallel() + + var receivedBody map[string]any + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + // Verify POST method + assert.Equal(t, http.MethodPost, r.Method) + + // Verify headers + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + assert.Equal(t, "text/event-stream", r.Header.Get("Accept")) + assert.Equal(t, "no-cache", r.Header.Get("Cache-Control")) + + // Read and verify body + body, _ := io.ReadAll(r.Body) + assert.NoError(t, json.Unmarshal(body, &receivedBody)) + + // Send SSE response + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + + fmt.Fprintf(w, "event: next\ndata: {\"data\": {\"value\": 42}}\n\n") + flusher.Flush() + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + flusher.Flush() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + Variables: []byte(`{"id": 123}`), + OperationName: "TestSub", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportSSE, + SSEMethod: common.SSEMethodPOST, + }, handler) + require.NoError(t, err) + defer cancel() + + // Verify request body + assert.Equal(t, "subscription { test }", receivedBody["query"]) + assert.Equal(t, float64(123), receivedBody["variables"].(map[string]any)["id"]) + assert.Equal(t, "TestSub", receivedBody["operationName"]) + + // Receive data message + msg := receive(t, time.Second) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "42") + + // Receive complete message + msg = receive(t, time.Second) + assert.Equal(t, common.MessageTypeComplete, msg.Type) + }) + + t.Run("passes custom headers", func(t *testing.T) { + t.Parallel() + + var receivedAuth string + var receivedCustom string + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + receivedCustom = r.Header.Get("X-Custom-Header") + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + headers := http.Header{ + "Authorization": []string{"Bearer token123"}, + "X-Custom-Header": []string{"custom-value"}, + } + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Headers: headers, + SSEMethod: common.SSEMethodPOST, + }, handler) + require.NoError(t, err) + defer cancel() + + receive(t, time.Second) + + assert.Equal(t, "Bearer token123", receivedAuth) + assert.Equal(t, "custom-value", receivedCustom) + }) + + t.Run("handles next event with data", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {\"user\": {\"name\": \"Alice\"}}}\n\n") + flusher.Flush() + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + flusher.Flush() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { user { name } }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) + require.NoError(t, err) + defer cancel() + + msg := receive(t, time.Second) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "Alice") + assert.Equal(t, common.MessageTypeData, msg.Type) + }) + + t.Run("handles error event", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + fmt.Fprintf(w, "event: error\ndata: [{\"message\": \"Something went wrong\"}]\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) + require.NoError(t, err) + defer cancel() + + msg := receive(t, time.Second) + assert.Equal(t, common.MessageTypeError, msg.Type) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Errors), "Something went wrong") + }) + + t.Run("handles complete event", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) + require.NoError(t, err) + defer cancel() + + msg := receive(t, time.Second) + assert.Equal(t, common.MessageTypeComplete, msg.Type) + assert.Nil(t, msg.Err) + assert.Nil(t, msg.Payload) + }) + + t.Run("handles multi-line data", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + // Multi-line data per SSE spec + fmt.Fprintf(w, "event: next\n") + fmt.Fprintf(w, "data: {\"data\": {\n") + fmt.Fprintf(w, "data: \"value\": 42\n") + fmt.Fprintf(w, "data: }}\n") + fmt.Fprintf(w, "\n") + flusher.Flush() + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + flusher.Flush() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) + require.NoError(t, err) + defer cancel() + + msg := receive(t, time.Second) + require.NotNil(t, msg.Payload) + // The multi-line data is joined with newlines + assert.Contains(t, string(msg.Payload.Data), "42") + }) + + t.Run("ignores SSE comments", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + + // Send some keep-alive comments + fmt.Fprintf(w, ": keep-alive\n") + fmt.Fprintf(w, ": another comment\n") + flusher.Flush() + + fmt.Fprintf(w, "event: next\ndata: {\"data\": {\"value\": 1}}\n\n") + flusher.Flush() + + fmt.Fprintf(w, ": more keep-alive\n") + flusher.Flush() + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + flusher.Flush() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, _ := collectingHandler() + wrappedHandler, collect := waitForMessages(handler) + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, wrappedHandler) + require.NoError(t, err) + defer cancel() + + msgs := collect(time.Second) + + // Should only receive 2 messages (next + complete), not comments + assert.Len(t, msgs, 2) + }) + + t.Run("cancel closes connection", func(t *testing.T) { + t.Parallel() + + serverClosed := make(chan struct{}) + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {}}\n\n") + flusher.Flush() + + // Wait for client to disconnect + <-r.Context().Done() + close(serverClosed) + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) + require.NoError(t, err) + + // Receive first message + receive(t, time.Second) + + assert.Equal(t, 1, tr.ConnCount()) + + // Cancel should close the connection + cancel() + + select { + case <-serverClosed: + // Good, server detected disconnect + case <-time.After(time.Second): + t.Fatal("server did not detect disconnect") + } + + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("context cancellation stops subscription", func(t *testing.T) { + t.Parallel() + + serverClosed := make(chan struct{}) + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {}}\n\n") + flusher.Flush() + + <-r.Context().Done() + close(serverClosed) + }) + + transportCtx, transportCancel := context.WithCancel(context.Background()) + + tr := NewSSETransport(transportCtx, http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(transportCtx, &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) + require.NoError(t, err) + defer cancel() + + _ = receive(t, time.Second) + + // Cancel context + transportCancel() + + select { + case <-serverClosed: + case <-time.After(10 * time.Second): + t.Fatal("server did not detect context cancellation") + } + }) + + t.Run("handles non-200 response", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + _, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, func(_ *common.Message) {}) + + require.Error(t, err) + assert.Contains(t, err.Error(), "401") + }) + + t.Run("handles non-200 with body", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("Internal server error")) + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + _, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, func(_ *common.Message) {}) + + require.Error(t, err) + assert.Contains(t, err.Error(), "500") + }) + + t.Run("creates separate connection per subscription", func(t *testing.T) { + t.Parallel() + + var reqCount atomic.Int32 + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + reqCount.Add(1) + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {}}\n\n") + flusher.Flush() + + // Keep connection open + <-r.Context().Done() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + opts := common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST} + + handler1, receive1 := collectingHandler() + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, handler1) + require.NoError(t, err) + + handler2, receive2 := collectingHandler() + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts, handler2) + require.NoError(t, err) + + receive1(t, time.Second) + receive2(t, time.Second) + + // SSE creates separate HTTP requests (no multiplexing) + assert.Equal(t, int32(2), reqCount.Load()) + assert.Equal(t, 2, tr.ConnCount()) + + cancel1() + cancel2() + }) + + t.Run("handles server closing stream", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {\"value\": 1}}\n\n") + flusher.Flush() + + // Server closes without sending complete + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) + require.NoError(t, err) + defer cancel() + + msg := receive(t, time.Second) + assert.NotNil(t, msg.Payload) + + // Server closes without sending complete — this is an abnormal + // disconnection per graphql-sse protocol, delivered as a connection error. + msg = receive(t, time.Second) + assert.Equal(t, common.MessageTypeConnectionError, msg.Type) + assert.Error(t, msg.Err) + }) + + t.Run("handles data without event type", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + // Some servers send data without explicit event type + fmt.Fprintf(w, "data: {\"data\": {\"value\": 99}}\n\n") + flusher.Flush() + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + flusher.Flush() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) + require.NoError(t, err) + defer cancel() + + msg := receive(t, time.Second) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "99") + }) +} + +func TestSSETransport_ContextCancellation(t *testing.T) { + t.Parallel() + + t.Run("context cancellation closes all connections", func(t *testing.T) { + t.Parallel() + + var closedCount atomic.Int32 + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {}}\n\n") + flusher.Flush() + + <-r.Context().Done() + closedCount.Add(1) + }) + + ctx, cancel := context.WithCancel(context.Background()) + tr := NewSSETransport(ctx, http.DefaultClient, nil) + + opts := common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST} + + handler1, receive1 := collectingHandler() + _, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, handler1) + require.NoError(t, err) + + handler2, receive2 := collectingHandler() + _, err = tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts, handler2) + require.NoError(t, err) + + receive1(t, time.Second) + receive2(t, time.Second) + + assert.Equal(t, 2, tr.ConnCount()) + + cancel() + + assert.Eventually(t, func() bool { + return closedCount.Load() == 2 + }, time.Second, 10*time.Millisecond) + + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) + }) +} + +func TestSSETransport_CustomClient(t *testing.T) { + t.Parallel() + + t.Run("uses custom http client", func(t *testing.T) { + t.Parallel() + + var customHeaderReceived string + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + customHeaderReceived = r.Header.Get("X-Custom-Client") + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + // Custom client with transport that adds a header + customClient := &http.Client{ + Transport: &headerTransport{ + base: http.DefaultTransport, + headers: http.Header{ + "X-Custom-Client": []string{"test-client"}, + }, + }, + } + + tr := NewSSETransport(t.Context(), customClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) + require.NoError(t, err) + defer cancel() + + receive(t, time.Second) + + assert.Equal(t, "test-client", customHeaderReceived) + }) +} + +// Test helpers + +func newSSEServer(t *testing.T, handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(handler)) + t.Cleanup(server.Close) + + return server +} + +// headerTransport is a custom RoundTripper that adds headers to requests +type headerTransport struct { + base http.RoundTripper + headers http.Header +} + +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + maps.Copy(req.Header, t.headers) + return t.base.RoundTrip(req) +} + +func TestSSETransport_ContentTypeValidation(t *testing.T) { + t.Parallel() + + t.Run("accepts text/event-stream with charset", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, handler) + require.NoError(t, err) + defer cancel() + + msg := receive(t, time.Second) + assert.Equal(t, common.MessageTypeComplete, msg.Type) + }) + + t.Run("rejects non-SSE content type", func(t *testing.T) { + t.Parallel() + + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"error": "not sse"}`)) + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + _, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{Endpoint: server.URL, SSEMethod: common.SSEMethodPOST}, func(_ *common.Message) {}) + + require.Error(t, err) + assert.True(t, strings.Contains(err.Error(), "content-type") || strings.Contains(err.Error(), "Content-Type")) + }) +} + +func TestSSETransport_GETMethod(t *testing.T) { + t.Parallel() + + t.Run("sends GET request with query parameters", func(t *testing.T) { + t.Parallel() + + var receivedMethod string + var receivedQuery string + var receivedVariables string + var receivedOperationName string + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + receivedMethod = r.Method + receivedQuery = r.URL.Query().Get("query") + receivedVariables = r.URL.Query().Get("variables") + receivedOperationName = r.URL.Query().Get("operationName") + + // Verify no body for GET + body, _ := io.ReadAll(r.Body) + assert.Empty(t, body) + + // Verify headers + assert.Equal(t, "text/event-stream", r.Header.Get("Accept")) + assert.Equal(t, "no-cache", r.Header.Get("Cache-Control")) + assert.Empty(t, r.Header.Get("Content-Type")) // No content-type for GET + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + + flusher := w.(http.Flusher) + fmt.Fprintf(w, "event: next\ndata: {\"data\": {\"value\": 42}}\n\n") + flusher.Flush() + + fmt.Fprintf(w, "event: complete\ndata:\n\n") + flusher.Flush() + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + Variables: []byte(`{"id": 123}`), + OperationName: "TestSub", + }, common.Options{ + Endpoint: server.URL, + SSEMethod: common.SSEMethodGET, + }, handler) + require.NoError(t, err) + defer cancel() + + // Verify GET method and query params + assert.Equal(t, http.MethodGet, receivedMethod) + assert.Equal(t, "subscription { test }", receivedQuery) + assert.Equal(t, `{"id":123}`, receivedVariables) + assert.Equal(t, "TestSub", receivedOperationName) + + // Receive data message + msg := receive(t, time.Second) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "42") + + // Receive complete message + msg = receive(t, time.Second) + assert.Equal(t, common.MessageTypeComplete, msg.Type) + }) + + t.Run("GET preserves existing query parameters", func(t *testing.T) { + t.Parallel() + + var receivedToken string + var receivedQuery string + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + receivedToken = r.URL.Query().Get("token") + receivedQuery = r.URL.Query().Get("query") + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL + "?token=abc123", + SSEMethod: common.SSEMethodGET, + }, handler) + require.NoError(t, err) + defer cancel() + + receive(t, time.Second) + + assert.Equal(t, "abc123", receivedToken) + assert.Equal(t, "subscription { test }", receivedQuery) + }) + + t.Run("GET passes custom headers", func(t *testing.T) { + t.Parallel() + + var receivedAuth string + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + headers := http.Header{ + "Authorization": []string{"Bearer token123"}, + } + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + SSEMethod: common.SSEMethodGET, + Headers: headers, + }, handler) + require.NoError(t, err) + defer cancel() + + receive(t, time.Second) + + assert.Equal(t, "Bearer token123", receivedAuth) + }) + + t.Run("GET omits empty variables and operationName", func(t *testing.T) { + t.Parallel() + + var hasVariables bool + var hasOperationName bool + server := newSSEServer(t, func(w http.ResponseWriter, r *http.Request) { + hasVariables = r.URL.Query().Has("variables") + hasOperationName = r.URL.Query().Has("operationName") + + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "event: complete\ndata:\n\n") + }) + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + // No variables or operationName + }, common.Options{ + Endpoint: server.URL, + SSEMethod: common.SSEMethodGET, + }, handler) + require.NoError(t, err) + defer cancel() + + receive(t, time.Second) + + assert.False(t, hasVariables, "variables should not be in query params") + assert.False(t, hasOperationName, "operationName should not be in query params") + }) +} + +func TestSSETransport_Subscribe_UnrecognizedMethod(t *testing.T) { + t.Parallel() + + t.Run("returns error for unrecognized SSE method", func(t *testing.T) { + t.Parallel() + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + _, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: "http://example.invalid", + SSEMethod: common.SSEMethod("PATCH"), + }, func(_ *common.Message) {}) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported SSE method") + }) + + t.Run("returns error when SSE method is empty", func(t *testing.T) { + t.Parallel() + + tr := NewSSETransport(t.Context(), http.DefaultClient, nil) + + _, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: "http://example.invalid", + }, func(_ *common.Message) {}) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported SSE method") + }) +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go new file mode 100644 index 0000000000..afb47bb390 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/transport_test.go @@ -0,0 +1,103 @@ +package transport + +import ( + "net/http" + "sync" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol" +) + +// newTestWSConnection mirrors the defaults applied by client.New so that tests +// can continue to pass wsConnectionOptions{...} literals without repeating +// default values at every call site. +func newTestWSConnection(t *testing.T, conn *websocket.Conn, proto protocol.Protocol, opts wsConnectionOptions) *wsConnection { + t.Helper() + if opts.logger == nil { + opts.logger = abstractlogger.NoopLogger + } + if opts.writeTimeout <= 0 { + opts.writeTimeout = 5 * time.Second + } + return newWSConnection(conn, proto, opts) +} + +// newTestWSTransport mirrors the defaults applied by client.New so that tests +// can continue to pass WSTransportOptions{...} literals without repeating +// default values at every call site. +func newTestWSTransport(t *testing.T, opts WSTransportOptions) *WSTransport { + t.Helper() + if opts.UpgradeClient == nil { + opts.UpgradeClient = http.DefaultClient + } + if opts.Logger == nil { + opts.Logger = abstractlogger.NoopLogger + } + if opts.ReadLimit <= 0 { + opts.ReadLimit = 1 << 20 + } + if opts.AckTimeout <= 0 { + opts.AckTimeout = 30 * time.Second + } + if opts.WriteTimeout <= 0 { + opts.WriteTimeout = 5 * time.Second + } + return NewWSTransport(t.Context(), opts) +} + +// collectingHandler returns a handler that appends messages to a channel, +// plus a helper to receive with timeout (for use in tests). +func collectingHandler() (common.Handler, func(t *testing.T, timeout time.Duration) *common.Message) { + ch := make(chan *common.Message, 64) + handler := func(msg *common.Message) { + ch <- msg + } + receive := func(t *testing.T, timeout time.Duration) *common.Message { + t.Helper() + select { + case msg := <-ch: + return msg + case <-time.After(timeout): + t.Fatal("timeout waiting for message") + return nil + } + } + return handler, receive +} + +// waitForMessages collects messages from a handler until a terminal message or timeout. +func waitForMessages(handler common.Handler) (common.Handler, func(timeout time.Duration) []*common.Message) { + var mu sync.Mutex + var msgs []*common.Message + done := make(chan struct{}, 1) + + wrappedHandler := func(msg *common.Message) { + mu.Lock() + msgs = append(msgs, msg) + mu.Unlock() + handler(msg) + if msg.Type.IsTerminal() { + select { + case done <- struct{}{}: + default: + } + } + } + + collect := func(timeout time.Duration) []*common.Message { + select { + case <-done: + case <-time.After(timeout): + } + mu.Lock() + defer mu.Unlock() + return append([]*common.Message{}, msgs...) + } + + return wrappedHandler, collect +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go new file mode 100644 index 0000000000..6f7c6a71c6 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn.go @@ -0,0 +1,284 @@ +package transport + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/coder/websocket" + "github.com/jensneuse/abstractlogger" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol" +) + +var ErrSubscriptionExists = errors.New("subscription ID already exists") + +type wsConnectionOptions struct { + logger abstractlogger.Logger + writeTimeout time.Duration + idleTimeout time.Duration + onEmpty func() +} + +type wsConnection struct { + conn *websocket.Conn + protocol protocol.Protocol + log abstractlogger.Logger + + // cancel cancels the connection-scoped context, unblocking readLoop and + // any in-flight writes. It is called exactly once inside shutdown(). + cancel context.CancelFunc + ctx context.Context + + subsMu sync.RWMutex + subs map[string]common.Handler + + closed atomic.Bool + + onEmpty func() + idleTimeout time.Duration + + writeTimeout time.Duration + + // Ping/pong tracking for client-initiated heartbeats. + // Values stored as UnixNano timestamps. + lastPingSentAt atomic.Int64 + lastPongAt atomic.Int64 +} + +func newWSConnection(conn *websocket.Conn, proto protocol.Protocol, opts wsConnectionOptions) *wsConnection { + if opts.logger == nil { + opts.logger = abstractlogger.NoopLogger + } + + ctx, cancel := context.WithCancel(context.Background()) + + c := &wsConnection{ + conn: conn, + protocol: proto, + log: opts.logger, + cancel: cancel, + ctx: ctx, + subs: make(map[string]common.Handler), + onEmpty: opts.onEmpty, + + writeTimeout: opts.writeTimeout, + idleTimeout: opts.idleTimeout, + } + + c.lastPongAt.Store(time.Now().UnixNano()) + + return c +} + +// subscribe registers handler for a new subscription on this connection and +// sends the protocol-level subscribe message. The returned function unsubscribes +// and, if this was the last subscription, triggers the idle-close flow. +func (c *wsConnection) subscribe(ctx context.Context, id string, req *common.Request, handler common.Handler) (func(), error) { + c.subsMu.Lock() + + if c.closed.Load() { + c.subsMu.Unlock() + return nil, common.ErrConnectionClosed + } + + if _, exists := c.subs[id]; exists { + c.subsMu.Unlock() + return nil, ErrSubscriptionExists + } + + c.subs[id] = handler + c.subsMu.Unlock() + + subscribeCtx, subscribeCancel := context.WithTimeout(ctx, c.writeTimeout) + defer subscribeCancel() + + if err := c.protocol.Subscribe(subscribeCtx, c.conn, id, req); err != nil { + c.log.Error("wsConnection.Subscribe", + abstractlogger.String("id", id), + abstractlogger.Error(err), + ) + c.removeSub(id) + return nil, err + } + + c.log.Debug("wsConnection.Subscribe", + abstractlogger.String("id", id), + abstractlogger.String("status", "subscribed"), + ) + + cancel := func() { c.unsubscribe(id) } + + return cancel, nil +} + +func (c *wsConnection) removeSub(id string) { + c.subsMu.Lock() + delete(c.subs, id) + isEmpty := len(c.subs) == 0 + c.subsMu.Unlock() + + if isEmpty { + if c.idleTimeout > 0 { + time.AfterFunc(c.idleTimeout, func() { + c.subsMu.RLock() + stillEmpty := len(c.subs) == 0 + c.subsMu.RUnlock() + if stillEmpty { + c.closeConn() + } + }) + } else { + c.closeConn() + } + } +} + +func (c *wsConnection) unsubscribe(id string) { + c.subsMu.Lock() + _, exists := c.subs[id] + c.subsMu.Unlock() + + if !exists { + return + } + + c.log.Debug("wsConnection.unsubscribe", abstractlogger.String("id", id)) + + unsubscribeCtx, cancel := context.WithTimeout(context.Background(), c.writeTimeout) + defer cancel() + + _ = c.protocol.Unsubscribe(unsubscribeCtx, c.conn, id) + + c.removeSub(id) +} + +func (c *wsConnection) readLoop() { + defer c.shutdown(errors.New("read loop exited")) + + for { + if c.closed.Load() { + return + } + + msg, err := c.protocol.Read(c.ctx, c.conn) + if err != nil { + c.log.Debug("wsConnection.ReadLoop", + abstractlogger.String("status", "error"), + abstractlogger.Error(err), + ) + c.shutdown(fmt.Errorf("%w: read: %w", common.ErrConnectionClosed, err)) + return + } + + switch msg.Type { + case protocol.MessagePing: + c.log.Debug("wsConnection.ReadLoop", abstractlogger.String("message", "ping")) + if pinger, ok := c.protocol.(protocol.Pinger); ok { + pongCtx, cancel := context.WithTimeout(c.ctx, c.writeTimeout) + _ = pinger.Pong(pongCtx, c.conn) + cancel() + } + case protocol.MessagePong: + c.lastPongAt.Store(time.Now().UnixNano()) + c.log.Debug("wsConnection.ReadLoop", abstractlogger.String("message", "pong")) + case protocol.MessageData, protocol.MessageError, protocol.MessageComplete: + c.dispatch(msg) + } + } +} + +func (c *wsConnection) dispatch(msg *protocol.WireMessage) { + c.subsMu.RLock() + handler, exists := c.subs[msg.ID] + c.subsMu.RUnlock() + + if !exists { + return + } + + handler(msg.IntoClientMessage()) + + if msg.Type == protocol.MessageComplete || msg.Type == protocol.MessageError { + c.removeSub(msg.ID) + } +} + +func (c *wsConnection) shutdown(err error) { + if !c.closed.CompareAndSwap(false, true) { + return + } + + c.log.Debug("wsConnection.shutdown", + abstractlogger.Error(err), + ) + + c.conn.Close(websocket.StatusNormalClosure, "shutdown") + + c.subsMu.Lock() + subs := c.subs + c.subs = make(map[string]common.Handler) + c.subsMu.Unlock() + + errMsg := &common.Message{Type: common.MessageTypeConnectionError, Err: err} + for _, handler := range subs { + handler(errMsg) + } + + // Cancel after dispatching errors so readLoop consumers still have a live + // context when they receive the error message. + c.cancel() + + if c.onEmpty != nil { + c.onEmpty() + } +} + +func (c *wsConnection) closeConn() { + c.shutdown(common.ErrConnectionClosed) +} + +func (c *wsConnection) subCount() int { + c.subsMu.RLock() + defer c.subsMu.RUnlock() + return len(c.subs) +} + +// sendPing sends a protocol-level ping message and records the timestamp. +// For protocols that don't implement Pinger (e.g. legacy graphql-ws), +// this is a no-op — lastPingSentAt stays zero so pongOverdue never triggers. +func (c *wsConnection) sendPing() error { + pinger, ok := c.protocol.(protocol.Pinger) + if !ok { + return nil + } + + pingCtx, cancel := context.WithTimeout(c.ctx, c.writeTimeout) + defer cancel() + + err := pinger.Ping(pingCtx, c.conn) + if err != nil { + return err + } + + c.lastPingSentAt.Store(time.Now().UnixNano()) + return nil +} + +// pongOverdue returns true if a pong has not been received since the last ping +// and the ping timeout has elapsed. +func (c *wsConnection) pongOverdue(timeout time.Duration) bool { + pingSent := c.lastPingSentAt.Load() + if pingSent == 0 { + return false + } + return c.lastPongAt.Load() < pingSent && time.Since(time.Unix(0, pingSent)) > timeout +} + +func (c *wsConnection) isClosed() bool { + return c.closed.Load() +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go new file mode 100644 index 0000000000..431180d125 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_conn_test.go @@ -0,0 +1,726 @@ +package transport + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "runtime" + "strings" + "sync" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol" +) + +func TestWSConnection_Subscribe(t *testing.T) { + t.Parallel() + + t.Run("calls protocol subscribe and handler can receive", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + handler, _ := collectingHandler() + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{ + Query: "subscription { test }", + }, handler) + require.NoError(t, err) + defer cancel() + assert.Len(t, proto.SubscribeCalls(), 1) + assert.Equal(t, "sub-1", proto.SubscribeCalls()[0].ID) + }) + + t.Run("returns error for duplicate subscription id", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + require.NoError(t, err) + defer cancel() + + _, err = wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + + assert.ErrorIs(t, err, ErrSubscriptionExists) + }) + + t.Run("returns error when connection is closed", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + wsc.closeConn() + + _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + + assert.ErrorIs(t, err, common.ErrConnectionClosed) + }) + + t.Run("returns ErrConnectionClosed when shutdown races before lock acquisition", func(t *testing.T) { + t.Parallel() + + // Test for TOCTOU between closed.Load() and subsMu.Lock() + // in subscribe(). + // + // Without a closed re-check under the lock, this sequence is possible: + // 1. subscribe: closed.Load() → false (check) + // 2. shutdown: closed.CAS(false,true) (invalidates the check) + // 3. shutdown: swaps c.subs, dispatches errors to old handlers + // 4. subscribe: subsMu.Lock(), adds handler to NEW empty map (use) + // 5. subscribe: protocol.Subscribe → nil (mock doesn't check conn) + // 6. subscribe returns (cancel, nil) — looks successful + // + // The handler is now orphaned: closed=true prevents any future + // shutdown from running (CAS fails), so the handler will never + // receive a terminal message. This test forces the exact interleaving + // deterministically. + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + // Step 1: Hold subsMu so subscribe() blocks after its closed check. + wsc.subsMu.Lock() + + subscribeResult := make(chan error, 1) + go func() { + // Passes closed.Load() (still false), then blocks on subsMu.Lock(). + _, err := wsc.subscribe(context.Background(), "sub-race", &common.Request{}, func(_ *common.Message) {}) + subscribeResult <- err + }() + + // Let the goroutine reach the blocked Lock(). + runtime.Gosched() + time.Sleep(5 * time.Millisecond) + + // Step 2: Simulate what shutdown does first — set closed=true. + // We set it directly rather than calling shutdown() because shutdown + // also needs subsMu (which we hold), and we want to control ordering. + wsc.closed.Store(true) + + // Step 3: Release the lock. subscribe() now enters the critical + // section with a stale view: it saw closed=false, but closed is now true. + wsc.subsMu.Unlock() + + select { + case err := <-subscribeResult: + // With the fix: subscribe re-checks closed under the lock → ErrConnectionClosed. + // Without the fix: subscribe succeeds (nil error) — the handler is orphaned + // because closed=true means no future shutdown() can deliver a terminal message. + require.ErrorIs(t, err, common.ErrConnectionClosed, + "subscribe must detect closed state under the lock; without this check "+ + "the handler is orphaned (closed=true prevents future shutdown)") + case <-time.After(time.Second): + t.Fatal("subscribe did not return within timeout") + } + }) + + t.Run("returns error when protocol subscribe fails", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + proto.subscribeErr = assert.AnError + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + + assert.Error(t, err) + assert.Equal(t, 0, wsc.subCount(), "failed subscription should not be registered") + }) +} + +func TestWSConnection_ReadLoop(t *testing.T) { + t.Parallel() + + t.Run("dispatches data message to subscription handler", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + handler, receive := collectingHandler() + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) + require.NoError(t, err) + defer cancel() + + go wsc.readLoop() + + proto.PushMessage(&protocol.WireMessage{ + ID: "sub-1", + Type: protocol.MessageData, + Payload: &common.ExecutionResult{Data: json.RawMessage(`{"value": 42}`)}, + }) + + msg := receive(t, time.Second) + require.NotNil(t, msg.Payload) + assert.Contains(t, string(msg.Payload.Data), "42") + }) + + t.Run("delivers complete message to handler", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + handler, receive := collectingHandler() + _, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) + require.NoError(t, err) + + go wsc.readLoop() + + proto.PushMessage(&protocol.WireMessage{ + ID: "sub-1", + Type: protocol.MessageComplete, + }) + + msg := receive(t, time.Second) + assert.Equal(t, common.MessageTypeComplete, msg.Type) + }) + + t.Run("responds to ping with pong", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + go wsc.readLoop() + + proto.PushMessage(&protocol.WireMessage{Type: protocol.MessagePing}) + + assert.Eventually(t, func() bool { + return proto.PongCount() > 0 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("ignores messages for unknown subscription ids", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + handler, receive := collectingHandler() + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) + require.NoError(t, err) + defer cancel() + + go wsc.readLoop() + + proto.PushMessage(&protocol.WireMessage{ + ID: "unknown-sub", + Type: protocol.MessageData, + Payload: &common.ExecutionResult{Data: json.RawMessage(`{"wrong": true}`)}, + }) + + proto.PushMessage(&protocol.WireMessage{ + ID: "sub-1", + Type: protocol.MessageData, + Payload: &common.ExecutionResult{Data: json.RawMessage(`{"right": true}`)}, + }) + + msg := receive(t, time.Second) + assert.Contains(t, string(msg.Payload.Data), "right") + }) +} + +func TestWSConnection_Unsubscribe(t *testing.T) { + t.Parallel() + + t.Run("calls protocol unsubscribe and removes subscription", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + require.NoError(t, err) + + cancel() + + assert.Len(t, proto.UnsubscribeCalls(), 1) + assert.Equal(t, "sub-1", proto.UnsubscribeCalls()[0]) + assert.Equal(t, 0, wsc.subCount()) + }) + + t.Run("is idempotent", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + require.NoError(t, err) + + cancel() + cancel() + cancel() + + assert.Len(t, proto.UnsubscribeCalls(), 1) + }) + + t.Run("times out using WriteTimeout", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + proto.unsubscribeDelay = 500 * time.Millisecond + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ + writeTimeout: 50 * time.Millisecond, + }) + + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + require.NoError(t, err) + + start := time.Now() + cancel() + elapsed := time.Since(start) + + assert.Less(t, elapsed, 200*time.Millisecond) + }) +} + +func TestWSConnection_OnEmpty(t *testing.T) { + t.Parallel() + + t.Run("calls callback when last subscription removed", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + emptyCalled := make(chan struct{}, 1) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ + onEmpty: func() { emptyCalled <- struct{}{} }, + }) + + cancel, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + cancel() + + select { + case <-emptyCalled: + // success + case <-time.After(100 * time.Millisecond): + t.Error("onEmpty callback not called") + } + + assert.True(t, wsc.isClosed(), "connection should be closed after last subscription removed") + }) + + t.Run("does not call callback when subscriptions remain", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + emptyCalled := make(chan struct{}, 1) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ + onEmpty: func() { emptyCalled <- struct{}{} }, + }) + + cancel1, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + cancel2, _ := wsc.subscribe(context.Background(), "sub-2", &common.Request{}, func(_ *common.Message) {}) + + cancel1() + + select { + case <-emptyCalled: + t.Error("onEmpty should not be called when subscriptions remain") + case <-time.After(100 * time.Millisecond): + // success + } + + cancel2() + + select { + case <-emptyCalled: + // success + case <-time.After(100 * time.Millisecond): + t.Error("onEmpty should be called after last subscription removed") + } + }) + + t.Run("calls callback on direct Close", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + emptyCalled := make(chan struct{}, 1) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ + onEmpty: func() { emptyCalled <- struct{}{} }, + }) + + wsc.closeConn() + + select { + case <-emptyCalled: + // success + case <-time.After(100 * time.Millisecond): + t.Error("onEmpty callback not called on Close") + } + }) + + t.Run("calls callback on read loop exit", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + emptyCalled := make(chan struct{}, 1) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ + onEmpty: func() { emptyCalled <- struct{}{} }, + }) + + go wsc.readLoop() + + // Close the connection to cause the read loop to exit + wsc.closeConn() + + select { + case <-emptyCalled: + // success + case <-time.After(time.Second): + t.Error("onEmpty callback not called on read loop exit") + } + }) +} + +func TestWSConnection_IdleTimeout(t *testing.T) { + t.Parallel() + + t.Run("removeSub defers close for idle timeout duration when subs are empty", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ + idleTimeout: 200 * time.Millisecond, + }) + + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + require.NoError(t, err) + + cancel() + + assert.False(t, wsc.isClosed(), "connection should not be closed immediately") + + assert.Eventually(t, func() bool { + return wsc.isClosed() + }, time.Second, 10*time.Millisecond, "connection should close after idle timeout") + }) + + t.Run("removeSub does not close when new subscription arrives before timeout", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ + idleTimeout: 200 * time.Millisecond, + }) + + cancel1, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + require.NoError(t, err) + + cancel1() // starts idle timer + + _, err = wsc.subscribe(context.Background(), "sub-2", &common.Request{}, func(_ *common.Message) {}) + require.NoError(t, err) + + time.Sleep(300 * time.Millisecond) + + assert.False(t, wsc.isClosed(), "connection should stay open while subscription exists") + }) + + t.Run("removeSub closes immediately when idle timeout is zero", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + + emptyCalled := make(chan struct{}, 1) + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ + onEmpty: func() { emptyCalled <- struct{}{} }, + }) + + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + require.NoError(t, err) + + cancel() + + select { + case <-emptyCalled: + // success + case <-time.After(100 * time.Millisecond): + t.Error("connection should close immediately with zero idle timeout") + } + + assert.True(t, wsc.isClosed()) + }) + +} + +func TestWSConnection_Close(t *testing.T) { + t.Parallel() + + t.Run("notifies all subscriptions with error", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + handler1, receive1 := collectingHandler() + _, _ = wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler1) + + handler2, receive2 := collectingHandler() + _, _ = wsc.subscribe(context.Background(), "sub-2", &common.Request{}, handler2) + + wsc.closeConn() + + msg1 := receive1(t, 100*time.Millisecond) + assert.Error(t, msg1.Err) + assert.Equal(t, common.MessageTypeConnectionError, msg1.Type) + + msg2 := receive2(t, 100*time.Millisecond) + assert.Error(t, msg2.Err) + assert.Equal(t, common.MessageTypeConnectionError, msg2.Type) + }) + + t.Run("is idempotent", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + assert.NotPanics(t, func() { + wsc.closeConn() + wsc.closeConn() + wsc.closeConn() + }) + }) +} + +func TestWSConnection_SubCount(t *testing.T) { + t.Parallel() + + t.Run("tracks subscription count accurately", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{}) + + assert.Equal(t, 0, wsc.subCount()) + + cancel1, _ := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, func(_ *common.Message) {}) + assert.Equal(t, 1, wsc.subCount()) + + cancel2, _ := wsc.subscribe(context.Background(), "sub-2", &common.Request{}, func(_ *common.Message) {}) + assert.Equal(t, 2, wsc.subCount()) + + cancel1() + assert.Equal(t, 1, wsc.subCount()) + + cancel2() + assert.Equal(t, 0, wsc.subCount()) + }) +} + +func TestWSConnection_WriteTimeout(t *testing.T) { + t.Parallel() + + t.Run("pong write respects WriteTimeout", func(t *testing.T) { + t.Parallel() + + conn, _ := newTestConn(t) + proto := newMockProtocol() + proto.pongDelay = 500 * time.Millisecond + wsc := newTestWSConnection(t, conn, proto, wsConnectionOptions{ + writeTimeout: 50 * time.Millisecond, + }) + + handler, receive := collectingHandler() + cancel, err := wsc.subscribe(context.Background(), "sub-1", &common.Request{}, handler) + require.NoError(t, err) + defer cancel() + + go wsc.readLoop() + + // Send ping (will trigger slow pong) + proto.PushMessage(&protocol.WireMessage{Type: protocol.MessagePing}) + + // Send data message right after + proto.PushMessage(&protocol.WireMessage{ + ID: "sub-1", + Type: protocol.MessageData, + Payload: &common.ExecutionResult{Data: json.RawMessage(`{"test": true}`)}, + }) + + // Should receive data within timeout + small buffer + // If pong blocked for 500ms, this would timeout + msg := receive(t, 150*time.Millisecond) + assert.NotNil(t, msg.Payload) + }) +} + +// Test helpers + +func newTestConn(t *testing.T) (*websocket.Conn, *websocket.Conn) { + t.Helper() + + serverConn := make(chan *websocket.Conn, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + require.NoError(t, err) + serverConn <- conn + + for { + _, _, err := conn.Read(r.Context()) + if err != nil { + conn.Close(websocket.StatusNormalClosure, "shutdown") + return + } + } + })) + + t.Cleanup(server.Close) + + url := "ws" + strings.TrimPrefix(server.URL, "http") + clientConn, _, err := websocket.Dial(context.Background(), url, nil) //nolint:bodyclose + require.NoError(t, err) + + t.Cleanup(func() { + clientConn.Close(websocket.StatusNormalClosure, "shutdown") + }) + + srvConn := <-serverConn + t.Cleanup(func() { _ = srvConn.CloseNow() }) + + return clientConn, srvConn +} + +// mockProtocol implements protocol.Protocol for testing. +type mockProtocol struct { + mu sync.Mutex + subscribeCalls []subscribeCall + unsubCalls []string + pongCount int + subscribeErr error + unsubscribeDelay time.Duration + pongDelay time.Duration + + messages chan *protocol.WireMessage +} + +type subscribeCall struct { + ID string + Req *common.Request +} + +func newMockProtocol() *mockProtocol { + return &mockProtocol{ + messages: make(chan *protocol.WireMessage, 100), + } +} + +func (m *mockProtocol) Subprotocol() string { return "graphql-transport-ws" } + +func (m *mockProtocol) Init(ctx context.Context, conn *websocket.Conn, payload map[string]any) error { + return nil +} + +func (m *mockProtocol) Subscribe(ctx context.Context, conn *websocket.Conn, id string, req *common.Request) error { + m.mu.Lock() + defer m.mu.Unlock() + m.subscribeCalls = append(m.subscribeCalls, subscribeCall{ID: id, Req: req}) + return m.subscribeErr +} + +func (m *mockProtocol) Unsubscribe(ctx context.Context, conn *websocket.Conn, id string) error { + if m.unsubscribeDelay > 0 { + select { + case <-time.After(m.unsubscribeDelay): + case <-ctx.Done(): + return ctx.Err() + } + } + + m.mu.Lock() + defer m.mu.Unlock() + m.unsubCalls = append(m.unsubCalls, id) + return nil +} + +func (m *mockProtocol) Read(ctx context.Context, conn *websocket.Conn) (*protocol.WireMessage, error) { + select { + case msg := <-m.messages: + return msg, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Ping and Pong implement protocol.Pinger — the mock simulates graphql-transport-ws. + +func (m *mockProtocol) Ping(ctx context.Context, conn *websocket.Conn) error { + return nil +} + +func (m *mockProtocol) Pong(ctx context.Context, conn *websocket.Conn) error { + if m.pongDelay > 0 { + select { + case <-time.After(m.pongDelay): + case <-ctx.Done(): + return ctx.Err() + } + } + + m.mu.Lock() + defer m.mu.Unlock() + m.pongCount++ + return nil +} + +func (m *mockProtocol) PushMessage(msg *protocol.WireMessage) { + m.messages <- msg +} + +func (m *mockProtocol) SubscribeCalls() []subscribeCall { + m.mu.Lock() + defer m.mu.Unlock() + return append([]subscribeCall{}, m.subscribeCalls...) +} + +func (m *mockProtocol) UnsubscribeCalls() []string { + m.mu.Lock() + defer m.mu.Unlock() + return append([]string{}, m.unsubCalls...) +} + +func (m *mockProtocol) PongCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.pongCount +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go new file mode 100644 index 0000000000..33c7a9b0af --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport.go @@ -0,0 +1,357 @@ +package transport + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "sync" + "time" + + "github.com/coder/websocket" + "github.com/jensneuse/abstractlogger" + "github.com/rs/xid" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/protocol" + "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" +) + +// ErrDialFailed indicates that the WebSocket dial (TCP + HTTP upgrade) failed. +// The underlying cause is available via errors.Unwrap. +var ErrDialFailed = errors.New("websocket dial failed") + +// ErrInitFailed indicates that the GraphQL protocol init (connection_init / +// connection_ack handshake) failed after a successful WebSocket dial. The +// underlying cause (e.g. protocol.ErrAckTimeout) is available via errors.Unwrap. +var ErrInitFailed = errors.New("protocol init failed") + +type ErrFailedUpgrade struct { + URL string + StatusCode int +} + +func (e ErrFailedUpgrade) Error() string { + return fmt.Sprintf("failed to upgrade connection to %s, status code: %d", e.URL, e.StatusCode) +} + +type ErrInvalidSubprotocol string + +func (e ErrInvalidSubprotocol) Error() string { + return fmt.Sprintf("provided websocket subprotocol '%s' is not supported. The supported subprotocols are graphql-ws and graphql-transport-ws. Please configure your subscriptions with the mentioned subprotocols", string(e)) +} + +// WSTransportOptions configures a WSTransport. +type WSTransportOptions struct { + // UpgradeClient is the HTTP client used for the WebSocket upgrade request. + UpgradeClient *http.Client + + // Logger is the logger used for transport and connection-level events. + Logger abstractlogger.Logger + + // ReadLimit is the maximum message size in bytes the WebSocket connection + // will accept. + ReadLimit int64 + + // PingInterval is how often the transport sends a ping to each connection. + // Zero disables pinging. + PingInterval time.Duration + + // PingTimeout is how long a connection may go without a pong before it is + // considered dead. Only meaningful when PingInterval is set. + PingTimeout time.Duration + + // AckTimeout is the maximum time to wait for a connection_ack during the + // protocol init handshake. Passed to the protocol at construction. + AckTimeout time.Duration + + // WriteTimeout is the deadline applied to each WebSocket write (subscribe, + // unsubscribe, ping, pong). Passed to each connection. + WriteTimeout time.Duration + + // IdleTimeout is the duration a connection stays open after its last + // subscription is removed, allowing new subscriptions to reuse it without + // re-dialing. Zero means close immediately. + IdleTimeout time.Duration +} + +type WSTransport struct { + ctx context.Context + opts WSTransportOptions + + // mu guards both dialing and conns. + // + // dialing coalesces concurrent dial attempts: waiters block on + // dialResult.done rather than each dialing independently. + // + // conns holds only fully established connections (dial + protocol init + // complete) so every entry is always fully usable, never in a partial + // ready state. + mu sync.Mutex + dialing map[uint64]*dialResult + conns map[uint64]*wsConnection +} + +type dialResult struct { + done chan struct{} + conn *wsConnection + err error +} + +// NewWSTransport creates a new WSTransport. Connections are not closed when ctx +// is cancelled; instead they close themselves when their last subscriber is +// removed via the resolver's drain chain. The ping loop exits on ctx cancellation. +// +// If PingInterval is set, a single goroutine sends protocol-level pings to all +// connections at that cadence. If PingTimeout is also set, connections that fail +// to respond with a pong within that window are shut down. +func NewWSTransport(ctx context.Context, opts WSTransportOptions) *WSTransport { + if opts.UpgradeClient == nil { + opts.UpgradeClient = http.DefaultClient + } + + if opts.Logger == nil { + opts.Logger = abstractlogger.NoopLogger + } + + t := &WSTransport{ + ctx: ctx, + opts: opts, + conns: make(map[uint64]*wsConnection), + dialing: make(map[uint64]*dialResult), + } + + if opts.PingInterval > 0 { + go t.pingLoop() + } + + return t +} + +// Subscribe initiates a GraphQL subscription over WebSocket. It reuses an +// existing connection when one is available for the same endpoint, subprotocol, +// headers, and init payload, dialing a new one otherwise. +func (t *WSTransport) Subscribe(ctx context.Context, req *common.Request, opts common.Options, handler common.Handler) (func(), error) { + conn, err := t.getOrDial(ctx, opts) + if err != nil { + return nil, err + } + + id := xid.New().String() + return conn.subscribe(ctx, id, req, handler) +} + +// pingLoop sends periodic pings to all active connections and shuts down +// any that have not responded with a pong in time. +func (t *WSTransport) pingLoop() { + tick := time.Tick(t.opts.PingInterval) + for { + select { + case <-t.ctx.Done(): + return + case <-tick: + t.mu.Lock() + conns := make([]*wsConnection, 0, len(t.conns)) + for _, conn := range t.conns { + conns = append(conns, conn) + } + t.mu.Unlock() + + for _, conn := range conns { + if conn.isClosed() { + continue + } + + if t.opts.PingTimeout > 0 && conn.pongOverdue(t.opts.PingTimeout) { + t.opts.Logger.Debug("wsTransport.pingLoop", + abstractlogger.String("action", "pong_timeout"), + ) + conn.closeConn() + continue + } + + if err := conn.sendPing(); err != nil { + t.opts.Logger.Debug("wsTransport.pingLoop", + abstractlogger.String("action", "ping_failed"), + abstractlogger.Error(err), + ) + } + } + } + } +} + +func (t *WSTransport) ConnCount() int { + t.mu.Lock() + defer t.mu.Unlock() + + return len(t.conns) +} + +func (t *WSTransport) getOrDial(ctx context.Context, opts common.Options) (*wsConnection, error) { + key := connKey(opts) + + t.mu.Lock() + + if conn, ok := t.conns[key]; ok && !conn.isClosed() { + t.mu.Unlock() + return conn, nil + } + + if result, ok := t.dialing[key]; ok { + t.mu.Unlock() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-result.done: + } + + if result.err != nil { + return nil, result.err + } + + return result.conn, nil + } + + result := &dialResult{done: make(chan struct{})} + t.dialing[key] = result + t.mu.Unlock() + + conn, err := t.dial(ctx, key, opts) + + result.conn = conn + result.err = err + close(result.done) + + t.mu.Lock() + delete(t.dialing, key) + + if err == nil { + t.conns[key] = conn + } + t.mu.Unlock() + + return conn, err +} + +func (t *WSTransport) dial(ctx context.Context, key uint64, opts common.Options) (*wsConnection, error) { + t.opts.Logger.Debug("wsTransport.dial", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.String("subprotocol", string(opts.WSSubprotocol)), + ) + + wsConn, resp, err := websocket.Dial(ctx, opts.Endpoint, &websocket.DialOptions{ //nolint:bodyclose + HTTPClient: t.opts.UpgradeClient, + Subprotocols: opts.WSSubprotocol.Subprotocols(), + HTTPHeader: opts.Headers, + }) + if err != nil { + if errors.Is(err, context.Canceled) { + return nil, err + } + + t.opts.Logger.Error("wsTransport.dial", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.Error(err), + ) + + // backwards compatibility with error handling in the router + if resp != nil && resp.StatusCode != http.StatusSwitchingProtocols { + return nil, ErrFailedUpgrade{URL: opts.Endpoint, StatusCode: resp.StatusCode} + } + + return nil, fmt.Errorf("%w: %w", ErrDialFailed, err) + } + + wsConn.SetReadLimit(t.opts.ReadLimit) + + proto, err := t.negotiateSubprotocol(opts.WSSubprotocol, wsConn.Subprotocol()) + if err != nil { + t.opts.Logger.Error("wsTransport.dial", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.String("error", "subprotocol negotiation failed"), + abstractlogger.Error(err), + ) + _ = wsConn.Close(websocket.StatusProtocolError, err.Error()) + return nil, err + } + + initCtx, initCancel := context.WithTimeout(ctx, t.opts.AckTimeout) + defer initCancel() + + if err := proto.Init(initCtx, wsConn, opts.InitPayload); err != nil { + t.opts.Logger.Error("wsTransport.dial", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.String("error", "protocol init failed"), + abstractlogger.Error(err), + ) + _ = wsConn.Close(websocket.StatusProtocolError, "init failed") + return nil, fmt.Errorf("%w: %w", ErrInitFailed, err) + } + + t.opts.Logger.Debug("wsTransport.dial", + abstractlogger.String("endpoint", opts.Endpoint), + abstractlogger.String("status", "connected"), + abstractlogger.String("negotiated_subprotocol", wsConn.Subprotocol()), + ) + + conn := newWSConnection(wsConn, proto, wsConnectionOptions{ + logger: t.opts.Logger, + writeTimeout: t.opts.WriteTimeout, + idleTimeout: t.opts.IdleTimeout, + onEmpty: func() { t.removeConn(key) }, + }) + + go conn.readLoop() + + return conn, nil +} + +func (t *WSTransport) negotiateSubprotocol(requested common.WSSubprotocol, accepted string) (protocol.Protocol, error) { + if requested != common.SubprotocolAuto { + if accepted != string(requested) { + return nil, ErrInvalidSubprotocol(accepted) + } + } + + switch common.WSSubprotocol(accepted) { + case common.SubprotocolGraphQLTransportWS: + return protocol.NewGraphQLTransportWS(), nil + case common.SubprotocolGraphQLWS: + return protocol.NewGraphQLWS(), nil + default: + return nil, ErrInvalidSubprotocol(accepted) + } +} + +func (t *WSTransport) removeConn(key uint64) { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.conns, key) +} + +// connKey computes a hash key for connection pooling. +func connKey(opts common.Options) uint64 { + h := pool.Hash64.Get() + defer pool.Hash64.Put(h) + + _, _ = h.WriteString(opts.Endpoint) + _, _ = h.WriteString("\x00") + + _, _ = h.WriteString(string(opts.WSSubprotocol)) + _, _ = h.WriteString("\x00") + + if len(opts.Headers) > 0 { + _ = opts.Headers.Write(h) + } + _, _ = h.WriteString("\x00") + + if len(opts.InitPayload) > 0 { + if data, err := json.Marshal(opts.InitPayload); err == nil { + _, _ = h.Write(data) + } + } + + return h.Sum64() +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go new file mode 100644 index 0000000000..57f7405c01 --- /dev/null +++ b/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/transport/ws_transport_test.go @@ -0,0 +1,1223 @@ +package transport + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource/subscriptionclient/common" +) + +func TestWSTransport_Subscribe(t *testing.T) { + t.Parallel() + + t.Run("dials and returns message channel", func(t *testing.T) { + t.Parallel() + + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read subscribe + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + assert.Equal(t, "subscribe", msg["type"]) + + // Send data + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 42}}, + }) + + // Send complete + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "complete", + }) + }) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }, handler) + require.NoError(t, err) + defer cancel() + + msg := receive(t, time.Second) + assert.Contains(t, string(msg.Payload.Data), "42") + + msg = receive(t, time.Second) + assert.Equal(t, common.MessageTypeComplete, msg.Type) + }) + + t.Run("reuses connection for same endpoint", func(t *testing.T) { + t.Parallel() + + var dialCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + dialCount.Add(1) + + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + }) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + opts := common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + } + + handler1, receive1 := collectingHandler() + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, handler1) + require.NoError(t, err) + defer cancel1() + + handler2, receive2 := collectingHandler() + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts, handler2) + require.NoError(t, err) + defer cancel2() + + // Both should receive messages + receive1(t, time.Second) + receive2(t, time.Second) + + // Only one connection should have been made + assert.Equal(t, int32(1), dialCount.Load()) + assert.Equal(t, 1, tr.ConnCount()) + }) + + t.Run("creates new connection for different headers", func(t *testing.T) { + t.Parallel() + + var dialCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + dialCount.Add(1) + + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + }) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + headers1 := http.Header{"Authorization": []string{"Bearer token1"}} + headers2 := http.Header{"Authorization": []string{"Bearer token2"}} + + handler1, receive1 := collectingHandler() + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + Headers: headers1, + }, handler1) + require.NoError(t, err) + defer cancel1() + + handler2, receive2 := collectingHandler() + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + Headers: headers2, + }, handler2) + require.NoError(t, err) + defer cancel2() + + receive1(t, time.Second) + receive2(t, time.Second) + + // Two connections due to different headers + assert.Equal(t, int32(2), dialCount.Load()) + assert.Equal(t, 2, tr.ConnCount()) + }) + + t.Run("creates new connection for different init payload", func(t *testing.T) { + t.Parallel() + + var dialCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + dialCount.Add(1) + + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + }) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + handler1, receive1 := collectingHandler() + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + InitPayload: map[string]any{"token": "abc"}, + }, handler1) + require.NoError(t, err) + defer cancel1() + + handler2, receive2 := collectingHandler() + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + InitPayload: map[string]any{"token": "xyz"}, + }, handler2) + require.NoError(t, err) + defer cancel2() + + receive1(t, time.Second) + receive2(t, time.Second) + + // Two connections due to different init payload + assert.Equal(t, int32(2), dialCount.Load()) + assert.Equal(t, 2, tr.ConnCount()) + }) + + t.Run("removes connection when all subscriptions closed", func(t *testing.T) { + t.Parallel() + + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + } + }) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + opts := common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + } + + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, func(_ *common.Message) {}) + require.NoError(t, err) + + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts, func(_ *common.Message) {}) + require.NoError(t, err) + + assert.Equal(t, 1, tr.ConnCount()) + + cancel1() + assert.Equal(t, 1, tr.ConnCount()) // still has one subscription + + cancel2() + + // Wait for onEmpty callback + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("redials after connection closed", func(t *testing.T) { + t.Parallel() + + var dialCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + dialCount.Add(1) + + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + }) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + opts := common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + } + + // First subscription + handler1, receive1 := collectingHandler() + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, handler1) + require.NoError(t, err) + receive1(t, time.Second) + cancel1() + + // Wait for connection to be removed + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) + + // Second subscription should redial + handler2, receive2 := collectingHandler() + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts, handler2) + require.NoError(t, err) + defer cancel2() + receive2(t, time.Second) + + assert.Equal(t, int32(2), dialCount.Load()) + }) +} + +func TestWSTransport_SubscriberDrain(t *testing.T) { + t.Parallel() + + t.Run("connection closes when last subscriber cancels", func(t *testing.T) { + t.Parallel() + + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + } + }) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + cancel, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }, func(_ *common.Message) {}) + require.NoError(t, err) + + assert.Equal(t, 1, tr.ConnCount()) + + cancel() + + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("connection stays open while subscribers remain", func(t *testing.T) { + t.Parallel() + + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + } + }) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + opts := common.Options{Endpoint: server.URL, Transport: common.TransportWS} + + cancel1, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { a }"}, opts, func(_ *common.Message) {}) + require.NoError(t, err) + + cancel2, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { b }"}, opts, func(_ *common.Message) {}) + require.NoError(t, err) + + assert.Equal(t, 1, tr.ConnCount()) + + cancel1() + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 1, tr.ConnCount(), "connection should stay open with remaining subscriber") + + cancel2() + + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) + }) +} + +func TestWSTransport_ConcurrentSubscribe(t *testing.T) { + t.Parallel() + + t.Run("handles concurrent subscribes to same endpoint", func(t *testing.T) { + t.Parallel() + + var dialCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + dialCount.Add(1) + + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + }) + + tr := newTestWSTransport(t, WSTransportOptions{ + IdleTimeout: 30 * time.Second, + }) + + opts := common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + } + + var wg sync.WaitGroup + for range 10 { + wg.Go(func() { + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{Query: "subscription { test }"}, opts, handler) + if err != nil { + t.Errorf("subscribe error: %v", err) + return + } + defer cancel() + + receive(t, time.Second) + }) + } + + wg.Wait() + + assert.Equal(t, int32(1), dialCount.Load()) + }) +} + +func TestWSTransport_InitPayloadForwarding(t *testing.T) { + t.Parallel() + + t.Run("forwards init payload to server with graphql-transport-ws protocol", func(t *testing.T) { + t.Parallel() + + receivedPayload := make(chan map[string]any, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + // Read connection_init and capture payload + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + if payload, ok := initMsg["payload"].(map[string]any); ok { + receivedPayload <- payload + } else { + receivedPayload <- nil + } + + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + // Read subscribe and respond + var subMsg map[string]any + if err := wsjson.Read(ctx, conn, &subMsg); err != nil { + return + } + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": subMsg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + })) + t.Cleanup(server.Close) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + initPayload := map[string]any{ + "Authorization": "Bearer secret-token", + "X-Custom": "custom-value", + "nested": map[string]any{ + "key": "nested-value", + }, + } + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLTransportWS, + InitPayload: initPayload, + }, handler) + require.NoError(t, err) + defer cancel() + + // Verify payload was received by server + select { + case payload := <-receivedPayload: + require.NotNil(t, payload, "server should receive init payload") + assert.Equal(t, "Bearer secret-token", payload["Authorization"]) + assert.Equal(t, "custom-value", payload["X-Custom"]) + nested, ok := payload["nested"].(map[string]any) + require.True(t, ok, "nested should be a map") + assert.Equal(t, "nested-value", nested["key"]) + case <-time.After(time.Second): + t.Fatal("timeout waiting for init payload") + } + + // Subscription should work + msg := receive(t, time.Second) + assert.NotNil(t, msg.Payload) + }) + + t.Run("forwards init payload to server with graphql-ws legacy protocol", func(t *testing.T) { + t.Parallel() + + receivedPayload := make(chan map[string]any, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + // Read connection_init and capture payload + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + if payload, ok := initMsg["payload"].(map[string]any); ok { + receivedPayload <- payload + } else { + receivedPayload <- nil + } + + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + // Read start and respond + var startMsg map[string]any + if err := wsjson.Read(ctx, conn, &startMsg); err != nil { + return + } + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": startMsg["id"], + "type": "data", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + })) + t.Cleanup(server.Close) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + initPayload := map[string]any{ + "token": "legacy-auth-token", + "version": float64(2), // JSON numbers are float64 + } + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLWS, // Legacy protocol + InitPayload: initPayload, + }, handler) + require.NoError(t, err) + defer cancel() + + // Verify payload was received by server + select { + case payload := <-receivedPayload: + require.NotNil(t, payload, "server should receive init payload") + assert.Equal(t, "legacy-auth-token", payload["token"]) + assert.Equal(t, float64(2), payload["version"]) + case <-time.After(time.Second): + t.Fatal("timeout waiting for init payload") + } + + // Subscription should work + msg := receive(t, time.Second) + assert.NotNil(t, msg.Payload) + }) + + t.Run("sends empty payload when init payload is nil", func(t *testing.T) { + t.Parallel() + + receivedPayload := make(chan map[string]any, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + // Read connection_init and capture payload + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + if payload, ok := initMsg["payload"].(map[string]any); ok { + receivedPayload <- payload + } else { + receivedPayload <- nil + } + + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + // Read subscribe and respond + var subMsg map[string]any + if err := wsjson.Read(ctx, conn, &subMsg); err != nil { + return + } + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": subMsg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + })) + t.Cleanup(server.Close) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLTransportWS, + InitPayload: nil, // No init payload + }, handler) + require.NoError(t, err) + defer cancel() + + // Server should receive nil/empty payload + select { + case payload := <-receivedPayload: + assert.Nil(t, payload, "server should receive nil payload when not provided") + case <-time.After(time.Second): + t.Fatal("timeout waiting for init message") + } + + // Subscription should still work + msg := receive(t, time.Second) + assert.NotNil(t, msg.Payload) + }) + + t.Run("same endpoint with different init payloads uses separate connections", func(t *testing.T) { + t.Parallel() + + var mu sync.Mutex + receivedPayloads := make([]map[string]any, 0) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + // Read connection_init and capture payload + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + + mu.Lock() + if payload, ok := initMsg["payload"].(map[string]any); ok { + receivedPayloads = append(receivedPayloads, payload) + } + mu.Unlock() + + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + // Handle subscriptions + for { + var msg map[string]any + if err := wsjson.Read(ctx, conn, &msg); err != nil { + return + } + if msg["type"] == "subscribe" { + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "next", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + } + } + })) + t.Cleanup(server.Close) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + // First subscription with user1 token + handler1, receive1 := collectingHandler() + cancel1, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLTransportWS, + InitPayload: map[string]any{"user": "user1"}, + }, handler1) + require.NoError(t, err) + defer cancel1() + + receive1(t, time.Second) + + // Second subscription with user2 token - should create new connection + handler2, receive2 := collectingHandler() + cancel2, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLTransportWS, + InitPayload: map[string]any{"user": "user2"}, + }, handler2) + require.NoError(t, err) + defer cancel2() + + receive2(t, time.Second) + + // Verify two separate connections were made with different payloads + assert.Equal(t, 2, tr.ConnCount()) + + mu.Lock() + defer mu.Unlock() + require.Len(t, receivedPayloads, 2) + + users := make([]string, 0, 2) + for _, p := range receivedPayloads { + if user, ok := p["user"].(string); ok { + users = append(users, user) + } + } + assert.ElementsMatch(t, []string{"user1", "user2"}, users) + }) +} + +func TestWSTransport_LegacyProtocol(t *testing.T) { + t.Parallel() + + t.Run("connects to legacy graphql-ws server", func(t *testing.T) { + t.Parallel() + + server := newLegacyGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read start message + var msg map[string]any + require.NoError(t, wsjson.Read(ctx, conn, &msg)) + assert.Equal(t, "start", msg["type"]) + + // Send data + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "data", + "payload": map[string]any{"data": map[string]any{"value": 42}}, + }) + + // Send complete + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "complete", + }) + }) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLWS, // Request legacy protocol + }, handler) + require.NoError(t, err) + defer cancel() + + msg := receive(t, time.Second) + assert.Contains(t, string(msg.Payload.Data), "42") + + msg = receive(t, time.Second) + assert.Equal(t, common.MessageTypeComplete, msg.Type) + }) + + t.Run("handles keep-alive messages", func(t *testing.T) { + t.Parallel() + + server := newLegacyGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read start message + var msg map[string]any + require.NoError(t, wsjson.Read(ctx, conn, &msg)) + + // Send keep-alive + _ = wsjson.Write(ctx, conn, map[string]string{"type": "ka"}) + + // Send data + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "data", + "payload": map[string]any{"data": map[string]any{"value": 1}}, + }) + + // Send complete + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "complete", + }) + }) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLWS, + }, handler) + require.NoError(t, err) + defer cancel() + + // Should receive data (keep-alive is handled internally) + msg := receive(t, time.Second) + assert.NotNil(t, msg.Payload) + + msg = receive(t, time.Second) + assert.Equal(t, common.MessageTypeComplete, msg.Type) + }) + + t.Run("auto-negotiates to legacy when modern unavailable", func(t *testing.T) { + t.Parallel() + + // Server only supports legacy protocol + server := newLegacyGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + require.NoError(t, wsjson.Read(ctx, conn, &msg)) + assert.Equal(t, "start", msg["type"]) // Should use legacy message type + + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "data", + "payload": map[string]any{"data": map[string]any{"value": 99}}, + }) + _ = wsjson.Write(ctx, conn, map[string]any{ + "id": msg["id"], + "type": "complete", + }) + }) + + tr := newTestWSTransport(t, WSTransportOptions{}) + + handler, receive := collectingHandler() + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolAuto, // Auto-negotiate + }, handler) + require.NoError(t, err) + defer cancel() + + msg := receive(t, time.Second) + assert.Contains(t, string(msg.Payload.Data), "99") + }) +} + +func TestWSTransport_Heartbeat(t *testing.T) { + t.Parallel() + + t.Run("sends pings and receives pongs", func(t *testing.T) { + t.Parallel() + + var pingCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read subscribe + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + + // Keep connection alive, respond to pings with pongs + for { + var incoming map[string]any + if err := wsjson.Read(ctx, conn, &incoming); err != nil { + return + } + if incoming["type"] == "ping" { + pingCount.Add(1) + _ = wsjson.Write(ctx, conn, map[string]string{"type": "pong"}) + } + } + }) + + tr := newTestWSTransport(t, WSTransportOptions{PingInterval: 50 * time.Millisecond}) + + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }, func(_ *common.Message) {}) + require.NoError(t, err) + defer cancel() + + // Wait for at least 2 pings to be sent + assert.Eventually(t, func() bool { + return pingCount.Load() >= 2 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("closes connection on pong timeout", func(t *testing.T) { + t.Parallel() + + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read subscribe + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + + // Read pings but never respond with pong + for { + var incoming map[string]any + if err := wsjson.Read(ctx, conn, &incoming); err != nil { + return + } + } + }) + + tr := newTestWSTransport(t, WSTransportOptions{PingInterval: 100 * time.Millisecond, PingTimeout: 50 * time.Millisecond}) + + handler, receive := collectingHandler() + _, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }, handler) + require.NoError(t, err) + + // Connection should be closed due to pong timeout, subscriber gets notified + msg := receive(t, time.Second) + assert.Equal(t, common.MessageTypeConnectionError, msg.Type) + assert.Error(t, msg.Err) + + assert.Eventually(t, func() bool { + return tr.ConnCount() == 0 + }, time.Second, 10*time.Millisecond) + }) + + t.Run("does not kill connection when ping timeout is disabled", func(t *testing.T) { + t.Parallel() + + var pingCount atomic.Int32 + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read subscribe + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + + // Read pings but never respond with pong + for { + var incoming map[string]any + if err := wsjson.Read(ctx, conn, &incoming); err != nil { + return + } + if incoming["type"] == "ping" { + pingCount.Add(1) + } + } + }) + + // PingInterval set, PingTimeout left at zero (disabled) + tr := newTestWSTransport(t, WSTransportOptions{PingInterval: 50 * time.Millisecond}) + + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }, func(_ *common.Message) {}) + require.NoError(t, err) + defer cancel() + + // Wait for several ping cycles without any pong responses + assert.Eventually(t, func() bool { + return pingCount.Load() >= 3 + }, time.Second, 10*time.Millisecond) + + // Connection must still be alive despite no pongs + assert.Equal(t, 1, tr.ConnCount()) + }) + + t.Run("keeps connection alive when pongs arrive", func(t *testing.T) { + t.Parallel() + + server := newGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read subscribe + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + + // Respond to pings with pongs + for { + var incoming map[string]any + if err := wsjson.Read(ctx, conn, &incoming); err != nil { + return + } + if incoming["type"] == "ping" { + _ = wsjson.Write(ctx, conn, map[string]string{"type": "pong"}) + } + } + }) + + tr := newTestWSTransport(t, WSTransportOptions{PingInterval: 50 * time.Millisecond, PingTimeout: 200 * time.Millisecond}) + + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + }, func(_ *common.Message) {}) + require.NoError(t, err) + defer cancel() + + // Connection should remain alive after several ping cycles + time.Sleep(250 * time.Millisecond) + assert.Equal(t, 1, tr.ConnCount()) + }) + + t.Run("legacy graphql-ws survives ping timeout with ka messages", func(t *testing.T) { + t.Parallel() + + // Server sends periodic ka (keep-alive) messages, never expects client pings. + server := newLegacyGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + + // Read in the background so the close handshake completes promptly. + closed := make(chan struct{}) + go func() { + defer close(closed) + var discard map[string]any + _ = wsjson.Read(ctx, conn, &discard) + }() + + for { + select { + case <-closed: + return + case <-ctx.Done(): + return + case <-time.After(30 * time.Millisecond): + if err := wsjson.Write(ctx, conn, map[string]string{"type": "ka"}); err != nil { + return + } + } + } + }) + + // Enable ping loop with a tight timeout. Legacy connections are + // unaffected because sendPing is a no-op for non-Pinger protocols, + // so lastPingSentAt stays zero and pongOverdue never triggers. + tr := newTestWSTransport(t, WSTransportOptions{ + PingInterval: 50 * time.Millisecond, + PingTimeout: 150 * time.Millisecond, + WriteTimeout: 100 * time.Millisecond, + }) + + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLWS, + }, func(_ *common.Message) {}) + require.NoError(t, err) + defer cancel() + + // Survive well past the ping timeout — several cycles. + time.Sleep(400 * time.Millisecond) + assert.Equal(t, 1, tr.ConnCount()) + }) + + t.Run("legacy graphql-ws does not send client pings", func(t *testing.T) { + t.Parallel() + + // Track any messages the server receives after the subscribe. + var extraMessages atomic.Int32 + server := newLegacyGraphQLWSServer(t, func(ctx context.Context, conn *websocket.Conn) { + // Read subscribe + var msg map[string]any + _ = wsjson.Read(ctx, conn, &msg) + + // Any further messages from the client are unexpected — legacy + // clients should never send ping. + for { + var incoming map[string]any + if err := wsjson.Read(ctx, conn, &incoming); err != nil { + return + } + extraMessages.Add(1) + } + }) + + tr := newTestWSTransport(t, WSTransportOptions{ + PingInterval: 50 * time.Millisecond, + }) + + cancel, err := tr.Subscribe(context.Background(), &common.Request{ + Query: "subscription { test }", + }, common.Options{ + Endpoint: server.URL, + Transport: common.TransportWS, + WSSubprotocol: common.SubprotocolGraphQLWS, + }, func(_ *common.Message) {}) + require.NoError(t, err) + defer cancel() + + // Wait long enough for several ping cycles to pass. + time.Sleep(200 * time.Millisecond) + + // Server should not have received any messages (no pings sent). + assert.Equal(t, int32(0), extraMessages.Load()) + }) +} + +// Test helpers + +func newGraphQLWSServer(t *testing.T, handler func(ctx context.Context, conn *websocket.Conn)) *httptest.Server { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-transport-ws"}, + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + // Handle connection_init + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + handler(ctx, conn) + })) + + t.Cleanup(server.Close) + return server +} + +func newLegacyGraphQLWSServer(t *testing.T, handler func(ctx context.Context, conn *websocket.Conn)) *httptest.Server { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"graphql-ws"}, // Legacy protocol only + }) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "") + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + // Handle connection_init + var initMsg map[string]any + if err := wsjson.Read(ctx, conn, &initMsg); err != nil { + return + } + if initMsg["type"] != "connection_init" { + return + } + _ = wsjson.Write(ctx, conn, map[string]string{"type": "connection_ack"}) + + handler(ctx, conn) + })) + + t.Cleanup(server.Close) + return server +} diff --git a/v2/pkg/engine/resolve/datasource.go b/v2/pkg/engine/resolve/datasource.go index b03bdd0781..be9169c1d3 100644 --- a/v2/pkg/engine/resolve/datasource.go +++ b/v2/pkg/engine/resolve/datasource.go @@ -18,11 +18,6 @@ type SubscriptionDataSource interface { Start(ctx *Context, headers http.Header, input []byte, updater SubscriptionUpdater) error } -type AsyncSubscriptionDataSource interface { - AsyncStart(ctx *Context, id uint64, headers http.Header, input []byte, updater SubscriptionUpdater) error - AsyncStop(id uint64) -} - // HookableSubscriptionDataSource is a hookable interface for subscription data sources. // It is used to call a function when a subscription is started. // This is useful for data sources that need to do some work when a subscription is started, diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index f735752ef9..bea484483f 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -10,11 +10,12 @@ import ( "io" "net/http" "runtime" + "sync" + "sync/atomic" "time" "github.com/buger/jsonparser" "github.com/pkg/errors" - "go.uber.org/atomic" "github.com/wundergraph/go-arena" @@ -26,10 +27,16 @@ const ( DefaultHeartbeatInterval = 5 * time.Second ) -// ConnectionIDs is used to create unique connection IDs for each subscription -// Whenever a new connection is created, use this to generate a new ID -// It is public because it can be used in more high level packages to instantiate a new connection -var ConnectionIDs = atomic.NewInt64(0) +// ConnectionID identifies a client connection for subscription routing. +type ConnectionID int64 + +// connectionIDCounter is the monotonic counter backing NewConnectionID. +var connectionIDCounter atomic.Int64 + +// NewConnectionID returns a unique ConnectionID via an atomic increment. +func NewConnectionID() ConnectionID { + return ConnectionID(connectionIDCounter.Add(1)) +} type Reporter interface { // SubscriptionUpdateSent called when a new subscription update is sent @@ -48,29 +55,34 @@ type AsyncErrorWriter interface { WriteError(ctx *Context, err error, res *GraphQLResponse, w io.Writer) } -// Resolver is a single threaded event loop that processes all events on a single goroutine. -// It is absolutely critical to ensure that all events are processed quickly to prevent blocking -// and that resolver modifications are done on the event loop goroutine. Long-running operations -// should be offloaded to the subscription worker goroutine. If a different goroutine needs to emit -// an event, it should be done through the events channel to avoid race conditions. +// Resolver manages GraphQL subscriptions using a mutex-protected trigger registry. +// All trigger/subscription state is guarded by mu. Long-running I/O (writes to clients) +// is performed outside the lock using a snapshot-and-release pattern. type Resolver struct { ctx context.Context options ResolverOptions maxConcurrency chan struct{} - triggers map[uint64]*trigger - events chan subscriptionEvent - triggerUpdateBuf *bytes.Buffer + // mu protects: shutdown, triggers, subscriptionsByID, subscriptionsByConnection. + // Lock ordering: subscriptionUpdater.mu > Resolver.mu > trigger.mu (then subscriptionState.writeMu outside those locks). + mu sync.Mutex + shutdown bool + triggers map[uint64]*trigger + subscriptionsByID map[SubscriptionIdentifier]*subscriptionState + subscriptionsByConnection map[ConnectionID]map[SubscriptionIdentifier]*subscriptionState allowedErrorExtensionFields map[string]struct{} allowedErrorFields map[string]struct{} - reporter Reporter - asyncErrorWriter AsyncErrorWriter + reporter Reporter + + // errorFormatter is a function provided by the router that formats Go errors into GraphQL responses and writes them to a writer. + // It's not really async, and it needs to be done under the writer's mutex. This is complex and should be resolved in the future. + errorFormatter AsyncErrorWriter propagateSubgraphErrors bool propagateSubgraphStatusCodes bool - // Subscription heartbeat interval + // Subscription heartbeat interval for periodic updater heartbeats. heartbeatInterval time.Duration // maxSubscriptionFetchTimeout defines the maximum time a subscription fetch can take before it is considered timed out maxSubscriptionFetchTimeout time.Duration @@ -91,7 +103,7 @@ type Resolver struct { } func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) { - r.asyncErrorWriter = w + r.errorFormatter = w } type tools struct { @@ -264,11 +276,11 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { options: options, propagateSubgraphErrors: options.PropagateSubgraphErrors, propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes, - events: make(chan subscriptionEvent), triggers: make(map[uint64]*trigger), + subscriptionsByID: make(map[SubscriptionIdentifier]*subscriptionState), + subscriptionsByConnection: make(map[ConnectionID]map[SubscriptionIdentifier]*subscriptionState), reporter: options.Reporter, - asyncErrorWriter: options.AsyncErrorWriter, - triggerUpdateBuf: bytes.NewBuffer(make([]byte, 0, 1024)), + errorFormatter: options.AsyncErrorWriter, allowedErrorExtensionFields: allowedExtensionFields, allowedErrorFields: allowedErrorFields, heartbeatInterval: options.SubscriptionHeartbeatInterval, @@ -283,7 +295,10 @@ func New(ctx context.Context, options ResolverOptions) *Resolver { resolver.maxConcurrency <- struct{}{} } - go resolver.processEvents() + go resolver.heartbeatLoop() + context.AfterFunc(resolver.ctx, func() { + resolver.shutdownResolver() + }) return resolver } @@ -436,132 +451,165 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe return resp, err } +// trigger groups subscriptions that share a data source and input. type trigger struct { + // mu protects subscriptions. + // Uses snapshot-and-release: held only during map access, released before I/O. + mu sync.RWMutex id uint64 cancel context.CancelFunc - subscriptions map[*Context]*sub - // initialized is set to true when the trigger is started and initialized - initialized bool + subscriptions map[SubscriptionIdentifier]*subscriptionState + // initialized is set to true when the trigger is started and initialized. + initialized atomic.Bool updater *subscriptionUpdater } func (t *trigger) subscriptionIds() map[context.Context]SubscriptionIdentifier { + t.mu.RLock() + defer t.mu.RUnlock() + subs := make(map[context.Context]SubscriptionIdentifier, len(t.subscriptions)) - for ctx, sub := range t.subscriptions { - subs[ctx.Context()] = sub.id + for _, sub := range t.subscriptions { + subs[sub.ctx.Context()] = sub.id } return subs } -// workItem is used to encapsulate a function that needs to be -// executed in the worker goroutine. fn will be executed, and if -// final is true the worker will be stopped after fn is executed. -type workItem struct { - fn func() - final bool -} - -type sub struct { - resolve *GraphQLSubscription - resolver *Resolver - ctx *Context - writer SubscriptionResponseWriter - id SubscriptionIdentifier - heartbeat bool - completed chan struct{} - // workChan is used to send work to the writer goroutine. All work is processed sequentially. - workChan chan workItem +// snapshotSubscriptions returns a point-in-time copy of all subscriptions. +func (t *trigger) snapshotSubscriptions() []*subscriptionState { + t.mu.RLock() + defer t.mu.RUnlock() + subs := make([]*subscriptionState, 0, len(t.subscriptions)) + for _, s := range t.subscriptions { + subs = append(subs, s) + } + return subs } -// startWorker runs in its own goroutine to process fetches and write data to the client synchronously -// it also takes care of sending heartbeats to the client but only if the subscription supports it -// TODO implement a goroutine pool that is sharded by the subscription id to avoid creating a new goroutine for each subscription -func (s *sub) startWorker() { - if s.heartbeat { - s.startWorkerWithHeartbeat() - return +// evalFilter runs SkipEvent for a single subscription. Must be called under t.mu. +func (t *trigger) evalFilter(s *subscriptionState, data []byte) (*subscriptionState, *pendingFilterError) { + if s.ctx.ctx.Err() != nil { + return nil, nil + } + skip, err := s.resolve.Filter.SkipEvent(s.ctx, data) + if err != nil { + fe := pendingFilterError{s.ctx, err, s.resolve.Response, s} + return nil, &fe } - s.startWorkerWithoutHeartbeat() + if skip { + return nil, nil + } + return s, nil } -// startWorkerWithHeartbeat is similar to startWorker but sends heartbeats to the client when enabled. -// It sends a heartbeat to the client every heartbeatInterval. Heartbeats are handled by the SubscriptionResponseWriter interface. -// TODO: Implement a shared timer implementation to avoid creating a new ticker for each subscription. -func (s *sub) startWorkerWithHeartbeat() { - heartbeatTicker := time.NewTicker(s.resolver.heartbeatInterval) - defer heartbeatTicker.Stop() +// filterSubscriptions evaluates SkipEvent for every active subscription and +// partitions them into pending updates and filter errors. +func (t *trigger) filterSubscriptions(data []byte) ([]*subscriptionState, []pendingFilterError) { + t.mu.Lock() + defer t.mu.Unlock() - for { - select { - case <-s.ctx.ctx.Done(): - // Complete when the client request context is done for synchronous subscriptions - s.close(SubscriptionCloseKindGoingAway) + var subs []*subscriptionState + var filterErrors []pendingFilterError - return - case <-s.resolver.ctx.Done(): - // Abort immediately if the resolver is shutting down - s.close(SubscriptionCloseKindGoingAway) + for _, s := range t.subscriptions { + pending, filterErr := t.evalFilter(s, data) + if pending != nil { + subs = append(subs, pending) + } + if filterErr != nil { + filterErrors = append(filterErrors, *filterErr) + } + } - return - case <-heartbeatTicker.C: - s.resolver.handleHeartbeat(s) - case work := <-s.workChan: - work.fn() + return subs, filterErrors +} - if work.final { - return - } +// filterSubscription evaluates SkipEvent for a single subscription by ID. +func (t *trigger) filterSubscription(id SubscriptionIdentifier, data []byte) (*subscriptionState, *pendingFilterError) { + t.mu.Lock() + defer t.mu.Unlock() - // Reset the heartbeat ticker after each write to avoid sending unnecessary heartbeats - heartbeatTicker.Reset(s.resolver.heartbeatInterval) - } + s, ok := t.subscriptions[id] + if !ok { + return nil, nil } -} -func (s *sub) startWorkerWithoutHeartbeat() { - for { - select { - case <-s.ctx.ctx.Done(): - // Complete when the client request context is done for synchronous subscriptions - s.close(SubscriptionCloseKindGoingAway) + sub, filterErr := t.evalFilter(s, data) - return - case <-s.resolver.ctx.Done(): - // Abort immediately if the resolver is shutting down - s.close(SubscriptionCloseKindGoingAway) + return sub, filterErr +} - return - case work := <-s.workChan: - work.fn() +// subscriptionState tracks a single active subscription. +type subscriptionState struct { + triggerID uint64 + resolve *GraphQLSubscription + ctx *Context + writer SubscriptionResponseWriter + id SubscriptionIdentifier + heartbeat bool + completed chan struct{} + // writeMu protects all writes to writer (Complete, Error, Write, Flush, Heartbeat). + // Paired with the removed atomic to prevent writes after removal. + writeMu sync.Mutex + // removed guards against writes after the subscription has been removed. + // Uses CompareAndSwap to prevent double-close of the completed channel. + removed atomic.Bool + // lastWriteTime stores unix nanos of the last successful data write. + lastWriteTime atomic.Int64 +} - if work.final { - return - } - } +func closeSubs(subs []*subscriptionState) { + for _, s := range subs { + s.done() } } -// Called when subgraph indicates a "complete" subscription -func (s *sub) complete() { - // The channel is used to communicate that the subscription is done - // It is used only in the synchronous subscription case and to avoid sending events - // to a subscription that is already done. - defer close(s.completed) +// done closes the completed channel to signal that the subscription is finished. +// It does not send any downstream messages — Complete/Error are sent separately. +func (s *subscriptionState) done() { + s.writeMu.Lock() + defer s.writeMu.Unlock() + close(s.completed) +} +// complete delivers a "subscription done" signal to the downstream writer. +// Called by handleTriggerComplete, not through toClose. +func (s *subscriptionState) complete() { + s.writeMu.Lock() + defer s.writeMu.Unlock() s.writer.Complete() } -// Called when subgraph becomes unreachable or closes the connection without a "complete" event -func (s *sub) close(kind SubscriptionCloseKind) { - // The channel is used to communicate that the subscription is done - // It is used only in the synchronous subscription case and to avoid sending events - // to a subscription that is already done. - defer close(s.completed) +// error delivers a terminal error payload to the downstream writer. +// Called by handleTriggerError, not through toClose. +func (s *subscriptionState) error(data []byte) { + s.writeMu.Lock() + defer s.writeMu.Unlock() + s.writer.Error(data) +} - s.writer.Close(kind) +// writeError delivers a formatted error to the downstream writer under writeMu. +func (s *subscriptionState) writeError(w AsyncErrorWriter, ctx *Context, err error, response *GraphQLResponse) { + s.writeMu.Lock() + defer s.writeMu.Unlock() + if s.removed.Load() { + return + } + w.WriteError(ctx, err, response, s.writer) } -func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, sharedInput []byte) { +// sendHeartbeat sends a keep-alive frame to the downstream writer under writeMu. +// @TODO: this is bad, see ENG-9356 +func (s *subscriptionState) sendHeartbeat() error { + s.writeMu.Lock() + defer s.writeMu.Unlock() + if s.removed.Load() { + return nil + } + return s.writer.Heartbeat() +} + +func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *subscriptionState, sharedInput []byte) { if r.options.Debug { fmt.Printf("resolver:trigger:subscription:update:%d\n", sub.id.SubscriptionID) } @@ -580,7 +628,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar if err := t.resolvable.InitSubscription(resolveCtx, input, sub.resolve.Trigger.PostProcessing); err != nil { r.resolveArenaPool.Release(resolveArena) - r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) + sub.writeError(r.errorFormatter, resolveCtx, err, sub.resolve.Response) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:init:failed:%d\n", sub.id.SubscriptionID) } @@ -592,7 +640,7 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar if err := t.loader.LoadGraphQLResponseData(resolveCtx, sub.resolve.Response, t.resolvable); err != nil { r.resolveArenaPool.Release(resolveArena) - r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) + sub.writeError(r.errorFormatter, resolveCtx, err, sub.resolve.Response) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:load:failed:%d\n", sub.id.SubscriptionID) } @@ -602,9 +650,17 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar return } + sub.writeMu.Lock() + if sub.removed.Load() { + sub.writeMu.Unlock() + r.resolveArenaPool.Release(resolveArena) + return + } + if err := t.resolvable.Resolve(resolveCtx.ctx, sub.resolve.Response.Data, sub.resolve.Response.Fetches, sub.writer); err != nil { r.resolveArenaPool.Release(resolveArena) - r.asyncErrorWriter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) + r.errorFormatter.WriteError(resolveCtx, err, sub.resolve.Response, sub.writer) + sub.writeMu.Unlock() if r.options.Debug { fmt.Printf("resolver:trigger:subscription:resolve:failed:%d\n", sub.id.SubscriptionID) } @@ -617,10 +673,13 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar r.resolveArenaPool.Release(resolveArena) if err := sub.writer.Flush(); err != nil { + sub.writeMu.Unlock() // If flush fails (e.g. client disconnected), remove the subscription. - _ = r.AsyncUnsubscribeSubscription(sub.id) + _ = r.UnsubscribeSubscription(sub.id) return } + sub.lastWriteTime.Store(time.Now().UnixNano()) + sub.writeMu.Unlock() if r.options.Debug { fmt.Printf("resolver:trigger:subscription:flushed:%d\n", sub.id.SubscriptionID) @@ -634,80 +693,18 @@ func (r *Resolver) executeSubscriptionUpdate(resolveCtx *Context, sub *sub, shar } } -// processEvents maintains the single threaded event loop that processes all events -func (r *Resolver) processEvents() { - done := r.ctx.Done() - - // events channel can't be closed here because producers are - // sending events across multiple goroutines - - for { - select { - case <-done: - r.handleShutdown() - return - case event := <-r.events: - r.handleEvent(event) - } - } -} - -// handleEvent is a single threaded function that processes events from the events channel -// All events are processed in the order they are received and need to be processed quickly -// to prevent blocking the event loop and any other events from being processed. -// TODO: consider using a worker pool that distributes events from different triggers to different workers -// to avoid blocking the event loop and improve performance. -func (r *Resolver) handleEvent(event subscriptionEvent) { - switch event.kind { - case subscriptionEventKindAddSubscription: - r.handleAddSubscription(event.triggerID, event.addSubscription) - case subscriptionEventKindRemoveSubscription: - r.handleRemoveSubscription(event.id) - case subscriptionEventKindCompleteSubscription: - r.handleCompleteSubscription(event.id) - case subscriptionEventKindRemoveClient: - r.handleRemoveClient(event.id.ConnectionID) - case subscriptionEventKindUpdateSubscription: - r.handleUpdateSubscription(event.triggerID, event.data, event.id) - case subscriptionEventKindTriggerUpdate: - r.handleTriggerUpdate(event.triggerID, event.data) - case subscriptionEventKindTriggerComplete: - r.handleTriggerComplete(event.triggerID) - case subscriptionEventKindTriggerInitialized: - r.handleTriggerInitialized(event.triggerID) - case subscriptionEventKindTriggerClose: - r.handleTriggerClose(event) - case subscriptionEventKindUnknown: - panic("unknown event") - } -} - -// handleHeartbeat sends a heartbeat to the client. It needs to be executed on the same goroutine as the writer. -func (r *Resolver) handleHeartbeat(sub *sub) { - if r.options.Debug { - fmt.Printf("resolver:heartbeat\n") - } - - if r.ctx.Err() != nil { - return - } - - if sub.ctx.Context().Err() != nil { - return - } - +func (r *Resolver) executeSubscriptionHeartbeat(sub *subscriptionState) { if r.options.Debug { fmt.Printf("resolver:heartbeat:subscription:%d\n", sub.id.SubscriptionID) } - if err := sub.writer.Heartbeat(); err != nil { - // If heartbeat fails (e.g. client disconnected), remove the subscription. - _ = r.AsyncUnsubscribeSubscription(sub.id) + if r.ctx.Err() != nil || sub.ctx.Context().Err() != nil { return } - if r.options.Debug { - fmt.Printf("resolver:heartbeat:subscription:done:%d\n", sub.id.SubscriptionID) + if err := sub.sendHeartbeat(); err != nil { + _ = r.UnsubscribeSubscription(sub.id) + return } if r.reporter != nil { @@ -715,89 +712,79 @@ func (r *Resolver) handleHeartbeat(sub *sub) { } } -func (r *Resolver) handleTriggerClose(s subscriptionEvent) { - if r.options.Debug { - fmt.Printf("resolver:trigger:shutdown:%d:%d\n", s.triggerID, s.id.SubscriptionID) - } - - r.closeTrigger(s.triggerID, s.closeKind) +type StartupHookContext struct { + Context context.Context + Updater func(data []byte) } -func (r *Resolver) handleTriggerInitialized(triggerID uint64) { - trig, ok := r.triggers[triggerID] +func (r *Resolver) executeStartupHooks(add *addSubscription, updater *subscriptionUpdater) error { + hook, ok := add.resolve.Trigger.Source.(HookableSubscriptionDataSource) if !ok { - return + return nil } - trig.initialized = true - - if r.reporter != nil { - r.reporter.TriggerCountInc(1) + hookCtx := StartupHookContext{ + Context: add.ctx.Context(), + Updater: func(data []byte) { + updater.UpdateSubscription(add.id, data) + }, } -} - -func (r *Resolver) handleTriggerComplete(triggerID uint64) { - if r.options.Debug { - fmt.Printf("resolver:trigger:complete:%d\n", triggerID) + err := hook.SubscriptionOnStart(hookCtx, add.input) + if err != nil && r.options.Debug { + fmt.Printf("resolver:trigger:subscription:startup:failed:%d\n", add.id.SubscriptionID) } - - r.completeTrigger(triggerID) + return err } -type StartupHookContext struct { - Context context.Context - Updater func(data []byte) +// registerSubscriptionLocked updates the by-ID and by-connection indexes. +func (r *Resolver) registerSubscriptionLocked(trig *trigger, s *subscriptionState) { + trig.mu.Lock() + trig.subscriptions[s.id] = s + trig.mu.Unlock() + id := s.id + r.subscriptionsByID[id] = s + byConn, ok := r.subscriptionsByConnection[id.ConnectionID] + if !ok { + byConn = make(map[SubscriptionIdentifier]*subscriptionState) + r.subscriptionsByConnection[id.ConnectionID] = byConn + } + byConn[id] = s } -func (r *Resolver) executeStartupHooks(add *addSubscription, updater *subscriptionUpdater) error { - hook, ok := add.resolve.Trigger.Source.(HookableSubscriptionDataSource) - if ok { - hookCtx := StartupHookContext{ - Context: add.ctx.Context(), - Updater: func(data []byte) { - // Writing on the updater channel is safe but has to happen outside of the event loop - // to respect order and not block the event loop - updater.UpdateSubscription(add.id, data) - }, - } - - err := hook.SubscriptionOnStart(hookCtx, add.input) - if err != nil { - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:startup:failed:%d\n", add.id.SubscriptionID) - } - r.asyncErrorWriter.WriteError(add.ctx, err, add.resolve.Response, add.writer) - _ = r.AsyncUnsubscribeSubscription(add.id) - return err - } +// unregisterSubscriptionLocked removes from the by-ID and by-connection indexes. +func (r *Resolver) unregisterSubscriptionLocked(id SubscriptionIdentifier) { + delete(r.subscriptionsByID, id) + byConn, ok := r.subscriptionsByConnection[id.ConnectionID] + if !ok { + return + } + delete(byConn, id) + if len(byConn) == 0 { + delete(r.subscriptionsByConnection, id.ConnectionID) } - return nil } -func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) { - var ( - err error - ) +// addSubscription registers a new subscription under the given trigger. +func (r *Resolver) addSubscription(triggerID uint64, add *addSubscription) error { + r.mu.Lock() + defer r.mu.Unlock() + if r.shutdown { + return r.ctx.Err() + } if r.options.Debug { fmt.Printf("resolver:trigger:subscription:add:%d:%d\n", triggerID, add.id.SubscriptionID) } - s := &sub{ + s := &subscriptionState{ + triggerID: triggerID, ctx: add.ctx, resolve: add.resolve, writer: add.writer, id: add.id, completed: add.completed, - workChan: make(chan workItem, 32), - resolver: r, } - if add.ctx.ExecutionOptions.SendHeartbeat { s.heartbeat = true } - // Start the dedicated worker goroutine where the subscription updates are processed - // and writes are written to the client in a single threaded manner - go s.startWorker() - trig, ok := r.triggers[triggerID] if ok { if r.reporter != nil { @@ -806,18 +793,16 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) if r.options.Debug { fmt.Printf("resolver:trigger:subscription:added:%d:%d\n", triggerID, add.id.SubscriptionID) } - // Execute the startup hooks in a separate goroutine to avoid blocking the event loop - s.workChan <- workItem{ - fn: func() { - _ = r.executeStartupHooks(add, trig.updater) - // if the startup hooks return an error, we don't have to do anything else - }, - final: false, - } - // After the startup hooks are executed, we can add the subscription to the subscriptions registry - // so that it can start receive events - trig.subscriptions[add.ctx] = s - return + // Register first so startup hooks can deliver initial data via UpdateSubscription. + r.registerSubscriptionLocked(trig, s) + // Execute the startup hooks in a goroutine to avoid holding the lock. + go func() { + if err := r.executeStartupHooks(add, trig.updater); err != nil { + s.writeError(r.errorFormatter, add.ctx, err, add.resolve.Response) + _ = r.UnsubscribeSubscription(add.id) + } + }() + return nil } if r.options.Debug { @@ -827,409 +812,457 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription) updater := &subscriptionUpdater{ debug: r.options.Debug, triggerID: triggerID, - ch: r.events, + resolver: r, ctx: ctx, } cloneCtx := add.ctx.clone(ctx) trig = &trigger{ id: triggerID, - subscriptions: make(map[*Context]*sub), + subscriptions: make(map[SubscriptionIdentifier]*subscriptionState), cancel: cancel, updater: updater, } r.triggers[triggerID] = trig - trig.subscriptions[add.ctx] = s updater.subsFn = trig.subscriptionIds + r.registerSubscriptionLocked(trig, s) if r.reporter != nil { r.reporter.SubscriptionCountInc(1) } - var asyncDataSource AsyncSubscriptionDataSource - - if async, ok := add.resolve.Trigger.Source.(AsyncSubscriptionDataSource); ok { - trig.cancel = func() { - cancel() - async.AsyncStop(triggerID) - } - asyncDataSource = async - } - go func() { if r.options.Debug { fmt.Printf("resolver:trigger:start:%d\n", triggerID) } - // This is blocking so the startup hook can decide if a subscription should be started or not by returning an error - err = r.executeStartupHooks(add, trig.updater) - if err != nil { - return - } - - if asyncDataSource != nil { - err = asyncDataSource.AsyncStart(cloneCtx, triggerID, add.headers, add.input, trig.updater) - } else { + // The startup hook is blocking so it can reject the subscription before Source.Start. + // If either step fails, broadcast the error to all subs and tear down the trigger. + err := r.executeStartupHooks(add, trig.updater) + if err == nil { err = add.resolve.Trigger.Source.Start(cloneCtx, add.headers, add.input, trig.updater) } if err != nil { if r.options.Debug { fmt.Printf("resolver:trigger:failed:%d\n", triggerID) } - r.asyncErrorWriter.WriteError(add.ctx, err, add.resolve.Response, add.writer) - _ = r.emitTriggerClose(triggerID) + for _, sub := range trig.snapshotSubscriptions() { + sub.writeError(r.errorFormatter, sub.ctx, err, sub.resolve.Response) + } + r.doneTriggerFromUpdater(triggerID) return } - _ = r.emitTriggerInitialized(triggerID) + r.markTriggerInitialized(triggerID) if r.options.Debug { fmt.Printf("resolver:trigger:started:%d\n", triggerID) } }() + return nil +} +func (r *Resolver) getTrigger(id uint64) (*trigger, bool) { + r.mu.Lock() + trig, ok := r.triggers[id] + r.mu.Unlock() + return trig, ok } -func (r *Resolver) emitTriggerClose(triggerID uint64) error { - if r.options.Debug { - fmt.Printf("resolver:trigger:shutdown:%d\n", triggerID) +// markTriggerInitialized marks a trigger as initialized and reports it. +func (r *Resolver) markTriggerInitialized(triggerID uint64) { + trig, ok := r.getTrigger(triggerID) + if !ok { + return } - - select { - case <-r.ctx.Done(): - return r.ctx.Err() - case r.events <- subscriptionEvent{ - triggerID: triggerID, - kind: subscriptionEventKindTriggerClose, - closeKind: SubscriptionCloseKindNormal, - }: + trig.initialized.Store(true) + if r.reporter != nil { + r.reporter.TriggerCountInc(1) } - - return nil } -func (r *Resolver) emitTriggerInitialized(triggerID uint64) error { +// doneTriggerFromUpdater performs cleanup for a trigger from a datasource/updater goroutine. +// It detaches the trigger, runs done toClose (close completed channels), and cancels the trigger context. +func (r *Resolver) doneTriggerFromUpdater(triggerID uint64) { if r.options.Debug { - fmt.Printf("resolver:trigger:initialized:%d\n", triggerID) + fmt.Printf("resolver:trigger:shutdown:%d\n", triggerID) } - - select { - case <-r.ctx.Done(): - return r.ctx.Err() - case r.events <- subscriptionEvent{ - triggerID: triggerID, - kind: subscriptionEventKindTriggerInitialized, - }: + r.mu.Lock() + res := r.detachTriggerLocked(triggerID) + if r.reporter != nil { + r.reporter.SubscriptionCountDec(res.removed) + if res.initialized { + r.reporter.TriggerCountDec(1) + } + } + r.mu.Unlock() + closeSubs(res.toClose) + if res.triggerCancel != nil { + res.triggerCancel() } - - return nil } -func (r *Resolver) handleCompleteSubscription(id SubscriptionIdentifier) { - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:remove:%d:%d\n", id.ConnectionID, id.SubscriptionID) +// handleTriggerComplete delivers a complete signal to all subscriptions on the trigger. +// Does NOT detach the trigger — Done() does that. +func (r *Resolver) handleTriggerComplete(triggerID uint64) { + trig, ok := r.getTrigger(triggerID) + if !ok { + return } - removed := 0 - for u := range r.triggers { - trig := r.triggers[u] - removed += r.completeTriggerSubscriptions(u, func(sID SubscriptionIdentifier) bool { - return sID == id - }) - if len(trig.subscriptions) == 0 { - r.completeTrigger(trig.id) + subs := trig.snapshotSubscriptions() + + for _, s := range subs { + if !s.removed.Load() { + s.complete() } } - if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - } } -func (r *Resolver) handleRemoveSubscription(id SubscriptionIdentifier) { - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:remove:%d:%d\n", id.ConnectionID, id.SubscriptionID) +// handleTriggerError delivers a terminal error to all subscriptions on the trigger, +// bypassing the resolve pipeline. Does NOT detach the trigger — Done() does that. +func (r *Resolver) handleTriggerError(triggerID uint64, data []byte) { + trig, ok := r.getTrigger(triggerID) + if !ok { + return } - removed := 0 - for u := range r.triggers { - trig := r.triggers[u] - removed += r.closeTriggerSubscriptions(u, SubscriptionCloseKindNormal, func(sID SubscriptionIdentifier) bool { - return sID == id - }) - if len(trig.subscriptions) == 0 { - r.closeTrigger(trig.id, SubscriptionCloseKindNormal) + subs := trig.snapshotSubscriptions() + + for _, s := range subs { + if !s.removed.Load() { + s.error(data) } } - if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - } } -func (r *Resolver) handleRemoveClient(id int64) { +func (r *Resolver) removeClient(id ConnectionID) removeClientResult { + r.mu.Lock() + defer r.mu.Unlock() + if r.shutdown { + return removeClientResult{} + } if r.options.Debug { fmt.Printf("resolver:trigger:subscription:remove:client:%d\n", id) } removed := 0 - for u := range r.triggers { - removed += r.closeTriggerSubscriptions(u, SubscriptionCloseKindNormal, func(sID SubscriptionIdentifier) bool { - return sID.ConnectionID == id - }) - if len(r.triggers[u].subscriptions) == 0 { - r.closeTrigger(r.triggers[u].id, SubscriptionCloseKindNormal) + toClose := make([]*subscriptionState, 0) + cancels := make([]context.CancelFunc, 0) + triggerDec := 0 + idsForConn := r.subscriptionsByConnection[id] + ids := make([]SubscriptionIdentifier, 0, len(idsForConn)) + for sid := range idsForConn { + ids = append(ids, sid) + } + for _, sid := range ids { + res := r.removeSubscriptionLocked(sid) + removed += res.removed + toClose = append(toClose, res.toClose...) + if res.triggerCancel != nil { + cancels = append(cancels, res.triggerCancel) + if res.initialized { + triggerDec++ + } } } + res := removeClientResult{ + removed: removed, + toClose: toClose, + cancels: cancels, + triggerDec: triggerDec, + } if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) + r.reporter.SubscriptionCountDec(res.removed) + if res.triggerDec > 0 { + r.reporter.TriggerCountDec(res.triggerDec) + } } + return res } -func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { - trig, ok := r.triggers[id] +// removeSubscriptionLocked removes a single subscription by id. +// r.mu must be held by the caller. +func (r *Resolver) removeSubscriptionLocked(id SubscriptionIdentifier) removeResult { + s, ok := r.subscriptionsByID[id] if !ok { - return - } - if r.options.Debug { - fmt.Printf("resolver:trigger:update:%d\n", id) + return removeResult{} } - for c, s := range trig.subscriptions { - r.sendUpdateToSubscription(data, c, s) + trig, ok := r.triggers[s.triggerID] + if !ok { + r.unregisterSubscriptionLocked(id) + return removeResult{} } -} -func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifier SubscriptionIdentifier) { - trig, ok := r.triggers[id] + trig.mu.Lock() + _, ok = trig.subscriptions[id] if !ok { - return + trig.mu.Unlock() + r.unregisterSubscriptionLocked(id) + return removeResult{} } - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:update:%d:%d,%d\n", id, subIdentifier.ConnectionID, subIdentifier.SubscriptionID) + var toClose []*subscriptionState + if s.removed.CompareAndSwap(false, true) { + toClose = append(toClose, s) } + delete(trig.subscriptions, id) + empty := len(trig.subscriptions) == 0 + trig.mu.Unlock() - for c, s := range trig.subscriptions { - if s.id != subIdentifier { - continue - } - r.sendUpdateToSubscription(data, c, s) - break - } -} + r.unregisterSubscriptionLocked(id) -func (r *Resolver) sendUpdateToSubscription(data []byte, c *Context, s *sub) { - if err := c.ctx.Err(); err != nil { - return // no need to schedule an event update when the client already disconnected - } - skip, err := s.resolve.Filter.SkipEvent(c, data, r.triggerUpdateBuf) - if err != nil { - r.asyncErrorWriter.WriteError(c, err, s.resolve.Response, s.writer) - return - } - if skip { - return - } - - fn := func() { - r.executeSubscriptionUpdate(c, s, data) + var triggerCancel context.CancelFunc + initialized := false + if empty { + delete(r.triggers, trig.id) + triggerCancel = trig.cancel + initialized = trig.initialized.Load() } - select { - case <-r.ctx.Done(): - // Skip sending all events if the resolver is shutting down - return - case <-c.ctx.Done(): - // Skip sending the event if the client disconnected - case s.workChan <- workItem{fn, false}: - // Send the event to the subscription worker + return removeResult{ + removed: 1, + toClose: toClose, + triggerCancel: triggerCancel, + initialized: initialized, } } -func (r *Resolver) closeTrigger(id uint64, kind SubscriptionCloseKind) { - if r.options.Debug { - fmt.Printf("resolver:trigger:close:%d\n", id) - } +// detachTriggerLocked removes all subscriptions for the trigger and removes the trigger from resolver maps. +// r.mu must be held by the caller. +func (r *Resolver) detachTriggerLocked(id uint64) removeResult { trig, ok := r.triggers[id] if !ok { - return + return removeResult{} } - removed := r.closeTriggerSubscriptions(id, kind, nil) + toClose := make([]*subscriptionState, 0, len(trig.subscriptions)) + removed := 0 - // Cancels the async datasource and cleanup the connection - trig.cancel() + trig.mu.Lock() + for sid, s := range trig.subscriptions { + if s.removed.CompareAndSwap(false, true) { + toClose = append(toClose, s) + } + delete(trig.subscriptions, sid) + r.unregisterSubscriptionLocked(sid) + removed++ + } + trig.mu.Unlock() delete(r.triggers, id) - if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - if trig.initialized { - r.reporter.TriggerCountDec(1) - } + return removeResult{ + removed: removed, + toClose: toClose, + triggerCancel: trig.cancel, + initialized: trig.initialized.Load(), } } -func (r *Resolver) completeTrigger(id uint64) { - if r.options.Debug { - fmt.Printf("resolver:trigger:complete:%d\n", id) - } +type removeResult struct { + removed int + toClose []*subscriptionState + triggerCancel context.CancelFunc // non-nil if trigger became empty + initialized bool // whether the removed trigger was initialized +} - trig, ok := r.triggers[id] +type removeClientResult struct { + removed int + toClose []*subscriptionState + cancels []context.CancelFunc + triggerDec int +} + +type pendingFilterError struct { + ctx *Context + err error + response *GraphQLResponse + sub *subscriptionState +} + +// handleTriggerUpdate sends data to all subscriptions of a trigger. +func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) { + trig, ok := r.getTrigger(id) if !ok { return } + if r.options.Debug { + fmt.Printf("resolver:trigger:update:%d\n", id) + } - removed := r.completeTriggerSubscriptions(id, nil) - - // Cancels the async datasource and cleanup the connection - trig.cancel() + subs, filterErrors := trig.filterSubscriptions(data) - delete(r.triggers, id) + for _, fe := range filterErrors { + fe.sub.writeError(r.errorFormatter, fe.ctx, fe.err, fe.response) + } - if r.reporter != nil { - r.reporter.SubscriptionCountDec(removed) - if trig.initialized { - r.reporter.TriggerCountDec(1) + var wg sync.WaitGroup + for _, sub := range subs { + if sub.removed.Load() { + continue } + wg.Go(func() { + r.executeSubscriptionUpdate(sub.ctx, sub, data) + }) } + wg.Wait() } -func (r *Resolver) completeTriggerSubscriptions(id uint64, completeMatcher func(a SubscriptionIdentifier) bool) int { - trig, ok := r.triggers[id] +// handleUpdateSubscription sends data to a single subscription. +func (r *Resolver) handleUpdateSubscription(id uint64, data []byte, subIdentifier SubscriptionIdentifier) { + trig, ok := r.getTrigger(id) if !ok { - return 0 + return } - removed := 0 - for c, s := range trig.subscriptions { - if completeMatcher != nil && !completeMatcher(s.id) { - continue - } - - // Send a work item to complete the subscription - s.workChan <- workItem{s.complete, true} - // Because the event loop is single threaded, we can safely close the channel from this sender - // The subscription worker will finish processing all events before the channel is closed. - close(s.workChan) + if r.options.Debug { + fmt.Printf("resolver:trigger:subscription:update:%d:%d,%d\n", id, subIdentifier.ConnectionID, subIdentifier.SubscriptionID) + } - // Important because we remove the subscription from the trigger on the same goroutine - // as we send work to the subscription worker. We can ensure that no new work is sent to the worker after this point. - delete(trig.subscriptions, c) + sub, filterErr := trig.filterSubscription(subIdentifier, data) - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:closed:%d:%d\n", trig.id, s.id.SubscriptionID) - } + if filterErr != nil { + filterErr.sub.writeError(r.errorFormatter, filterErr.ctx, filterErr.err, filterErr.response) + } - removed++ + if sub != nil && !sub.removed.Load() { + r.executeSubscriptionUpdate(sub.ctx, sub, data) } - return removed } -func (r *Resolver) closeTriggerSubscriptions(id uint64, closeKind SubscriptionCloseKind, closeMatcher func(a SubscriptionIdentifier) bool) int { - trig, ok := r.triggers[id] +func (r *Resolver) heartbeatTriggerSubscriptions(id uint64) { + trig, ok := r.getTrigger(id) if !ok { - return 0 + return } - removed := 0 - for c, s := range trig.subscriptions { - if closeMatcher != nil && !closeMatcher(s.id) { + + subs := trig.snapshotSubscriptions() + targets := make([]*subscriptionState, 0, len(subs)) + for _, s := range subs { + if !s.heartbeat || s.removed.Load() { continue } - - // Send a work item to close the subscription - s.workChan <- workItem{func() { s.close(closeKind) }, true} - - // Because the event loop is single threaded, we can safely close the channel from this sender - // The subscription worker will finish processing all events before the channel is closed. - close(s.workChan) - - // Important because we remove the subscription from the trigger on the same goroutine - // as we send work to the subscription worker. We can ensure that no new work is sent to the worker after this point. - delete(trig.subscriptions, c) - - if r.options.Debug { - fmt.Printf("resolver:trigger:subscription:closed:%d:%d\n", trig.id, s.id.SubscriptionID) + if time.Since(time.Unix(0, s.lastWriteTime.Load())) < r.heartbeatInterval { + continue } + targets = append(targets, s) + } - removed++ + for _, s := range targets { + r.executeSubscriptionHeartbeat(s) } - return removed } -func (r *Resolver) handleShutdown() { +func (r *Resolver) shutdownResolver() { if r.options.Debug { fmt.Printf("resolver:trigger:shutdown\n") } + r.mu.Lock() + if r.shutdown { + r.mu.Unlock() + return + } + + r.shutdown = true + triggerIDs := make([]uint64, 0, len(r.triggers)) for id := range r.triggers { - r.closeTrigger(id, SubscriptionCloseKindGoingAway) + triggerIDs = append(triggerIDs, id) + } + + allToClose := make([]*subscriptionState, 0) + cancels := make([]context.CancelFunc, 0, len(triggerIDs)) + removedTotal := 0 + triggerDec := 0 + + for _, id := range triggerIDs { + res := r.detachTriggerLocked(id) + removedTotal += res.removed + allToClose = append(allToClose, res.toClose...) + if res.triggerCancel != nil { + cancels = append(cancels, res.triggerCancel) + } + if res.initialized { + triggerDec++ + } + } + + if r.reporter != nil { + r.reporter.SubscriptionCountDec(removedTotal) + if triggerDec > 0 { + r.reporter.TriggerCountDec(triggerDec) + } + } + + r.triggers = make(map[uint64]*trigger) + r.subscriptionsByID = make(map[SubscriptionIdentifier]*subscriptionState) + r.subscriptionsByConnection = make(map[ConnectionID]map[SubscriptionIdentifier]*subscriptionState) + r.mu.Unlock() + + closeSubs(allToClose) + for _, cancel := range cancels { + cancel() } + if r.options.Debug { fmt.Printf("resolver:trigger:shutdown:done\n") } - r.triggers = make(map[uint64]*trigger) } -type SubscriptionIdentifier struct { - ConnectionID int64 - SubscriptionID int64 +func (r *Resolver) heartbeatLoop() { + ticker := time.NewTicker(r.heartbeatInterval) + defer ticker.Stop() + for { + select { + case <-r.ctx.Done(): + return + case <-ticker.C: + r.sendTriggerHeartbeats() + } + } } -func (r *Resolver) AsyncCompleteSubscription(id SubscriptionIdentifier) error { - select { - case <-r.ctx.Done(): - return r.ctx.Err() - case r.events <- subscriptionEvent{ - id: id, - kind: subscriptionEventKindCompleteSubscription, - }: +func (r *Resolver) sendTriggerHeartbeats() { + r.mu.Lock() + triggerIDs := make([]uint64, 0, len(r.triggers)) + for id := range r.triggers { + triggerIDs = append(triggerIDs, id) + } + r.mu.Unlock() + + for _, id := range triggerIDs { + r.heartbeatTriggerSubscriptions(id) } - return nil } -func (r *Resolver) AsyncUnsubscribeSubscription(id SubscriptionIdentifier) error { - select { - case <-r.ctx.Done(): +type SubscriptionIdentifier struct { + ConnectionID ConnectionID + SubscriptionID int64 +} + +func (r *Resolver) UnsubscribeSubscription(id SubscriptionIdentifier) error { + r.mu.Lock() + if r.shutdown { + r.mu.Unlock() return r.ctx.Err() - case r.events <- subscriptionEvent{ - id: id, - kind: subscriptionEventKindRemoveSubscription, - }: - default: - // In the event we cannot insert immediately, defer insertion a goroutine, this should prevent deadlocks, at the cost of goroutine creation. - go func() { - select { - case <-r.ctx.Done(): - return - case r.events <- subscriptionEvent{ - id: id, - kind: subscriptionEventKindRemoveSubscription, - }: - } - }() + } + res := r.removeSubscriptionLocked(id) + if r.reporter != nil { + r.reporter.SubscriptionCountDec(res.removed) + if res.triggerCancel != nil && res.initialized { + r.reporter.TriggerCountDec(1) + } + } + r.mu.Unlock() + closeSubs(res.toClose) + if res.triggerCancel != nil { + res.triggerCancel() } return nil } -func (r *Resolver) AsyncUnsubscribeClient(connectionID int64) error { - select { - case <-r.ctx.Done(): - return r.ctx.Err() - case r.events <- subscriptionEvent{ - id: SubscriptionIdentifier{ - ConnectionID: connectionID, - }, - kind: subscriptionEventKindRemoveClient, - }: - default: - // In the event we cannot insert immediately, defer insertion a goroutine, this should prevent deadlocks, at the cost of goroutine creation. - go func() { - select { - case <-r.ctx.Done(): - return - case r.events <- subscriptionEvent{ - id: SubscriptionIdentifier{ - ConnectionID: connectionID, - }, - kind: subscriptionEventKindRemoveClient, - }: - } - }() +func (r *Resolver) UnsubscribeClient(connectionID ConnectionID) error { + res := r.removeClient(connectionID) + closeSubs(res.toClose) + for _, cancel := range res.cancels { + cancel() } return nil } @@ -1293,7 +1326,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) id := SubscriptionIdentifier{ - ConnectionID: ConnectionIDs.Inc(), + ConnectionID: NewConnectionID(), SubscriptionID: 0, } if r.options.Debug { @@ -1302,43 +1335,34 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ completed := make(chan struct{}) - select { - case <-r.ctx.Done(): - // Stop processing if the resolver is shutting down - return r.ctx.Err() - case r.events <- subscriptionEvent{ - triggerID: triggerID, - kind: subscriptionEventKindAddSubscription, - addSubscription: &addSubscription{ - ctx: ctx, - input: input, - resolve: subscription, - writer: writer, - id: id, - completed: completed, - sourceName: subscription.Trigger.SourceName, - headers: headers, - }, - }: + if err := r.addSubscription(triggerID, &addSubscription{ + ctx: ctx, + input: input, + resolve: subscription, + writer: writer, + id: id, + completed: completed, + sourceName: subscription.Trigger.SourceName, + headers: headers, + }); err != nil { + return err } // This will immediately block until one of the following conditions is met: select { case <-ctx.ctx.Done(): // Client disconnected, request context canceled. - // We will ignore the error and remove the subscription in the next step. - + _ = r.UnsubscribeSubscription(id) select { case <-completed: // Wait for the subscription to be completed to avoid race conditions // with go sdk request shutdown. case <-r.ctx.Done(): - // Resolver shutdown, no way to gracefully shut down the subscription + // Resolver shutdown return r.ctx.Err() } case <-r.ctx.Done(): - // Resolver shutdown, no way to gracefully shut down the subscription - // because the event loop is not running anymore and shutdown all triggers + subscriptions + // Resolver shutdown return r.ctx.Err() case <-completed: } @@ -1348,12 +1372,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ } // Remove the subscription when the client disconnects. - - r.events <- subscriptionEvent{ - triggerID: triggerID, - kind: subscriptionEventKindRemoveSubscription, - id: id, - } + _ = r.UnsubscribeSubscription(id) return nil } @@ -1395,31 +1414,22 @@ func (r *Resolver) AsyncResolveGraphQLSubscription(ctx *Context, subscription *G return nil } + if err := ctx.ctx.Err(); err != nil { + return err + } + headers, triggerID := r.prepareTrigger(ctx, subscription.Trigger.SourceName, input) - select { - case <-r.ctx.Done(): - // Stop resolving if the resolver is shutting down - return r.ctx.Err() - case <-ctx.ctx.Done(): - // Stop resolving if the client is gone - return ctx.ctx.Err() - case r.events <- subscriptionEvent{ - triggerID: triggerID, - kind: subscriptionEventKindAddSubscription, - addSubscription: &addSubscription{ - ctx: ctx, - input: input, - resolve: subscription, - writer: writer, - id: id, - completed: make(chan struct{}), - sourceName: subscription.Trigger.SourceName, - headers: headers, - }, - }: - } - return nil + return r.addSubscription(triggerID, &addSubscription{ + ctx: ctx, + input: input, + resolve: subscription, + writer: writer, + id: id, + completed: make(chan struct{}), + sourceName: subscription.Trigger.SourceName, + headers: headers, + }) } func (r *Resolver) subscriptionInput(ctx *Context, subscription *GraphQLSubscription) (input []byte, err error) { @@ -1444,47 +1454,55 @@ func (r *Resolver) subscriptionInput(ctx *Context, subscription *GraphQLSubscrip return input, nil } +// subscriptionUpdater implements SubscriptionUpdater, the callback API for data sources. type subscriptionUpdater struct { + // mu serves two roles: + // + // 1. Event serialization gate -- held across the entire Update() call including + // wg.Wait(), ensuring event A fully completes before event B begins. + // + // 2. Lifecycle guard -- the done flag prevents callbacks after Done() has torn down + // the trigger. Every method checks done || ctx.Err() under the lock before proceeding. + mu sync.Mutex + done bool debug bool triggerID uint64 - ch chan subscriptionEvent + resolver *Resolver ctx context.Context subsFn func() map[context.Context]SubscriptionIdentifier } func (s *subscriptionUpdater) Update(data []byte) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done || s.ctx.Err() != nil { + return + } if s.debug { fmt.Printf("resolver:subscription_updater:update:%d\n", s.triggerID) } + s.resolver.handleTriggerUpdate(s.triggerID, data) +} - select { - case <-s.ctx.Done(): - // Skip sending events if trigger is already done +func (s *subscriptionUpdater) Heartbeat() { + s.mu.Lock() + defer s.mu.Unlock() + if s.done || s.ctx.Err() != nil { return - case s.ch <- subscriptionEvent{ - triggerID: s.triggerID, - kind: subscriptionEventKindTriggerUpdate, - data: data, - }: } + s.resolver.heartbeatTriggerSubscriptions(s.triggerID) } func (s *subscriptionUpdater) UpdateSubscription(id SubscriptionIdentifier, data []byte) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done || s.ctx.Err() != nil { + return + } if s.debug { fmt.Printf("resolver:subscription_updater:update:%d\n", s.triggerID) } - - select { - case <-s.ctx.Done(): - // Skip sending events if trigger is already done - return - case s.ch <- subscriptionEvent{ - triggerID: s.triggerID, - kind: subscriptionEventKindUpdateSubscription, - data: data, - id: id, - }: - } + s.resolver.handleUpdateSubscription(s.triggerID, data, id) } func (s *subscriptionUpdater) Subscriptions() map[context.Context]SubscriptionIdentifier { @@ -1492,81 +1510,62 @@ func (s *subscriptionUpdater) Subscriptions() map[context.Context]SubscriptionId } func (s *subscriptionUpdater) Complete() { - if s.debug { - fmt.Printf("resolver:subscription_updater:complete:%d\n", s.triggerID) - } - - select { - case <-s.ctx.Done(): - // Skip sending events if trigger is already done + s.mu.Lock() + defer s.mu.Unlock() + if s.done || s.ctx.Err() != nil { if s.debug { fmt.Printf("resolver:subscription_updater:complete:skip:%d\n", s.triggerID) } return - case s.ch <- subscriptionEvent{ - triggerID: s.triggerID, - kind: subscriptionEventKindTriggerComplete, - }: - if s.debug { - fmt.Printf("resolver:subscription_updater:complete:sent_event:%d\n", s.triggerID) - } } -} - -func (s *subscriptionUpdater) Close(kind SubscriptionCloseKind) { if s.debug { - fmt.Printf("resolver:subscription_updater:close:%d\n", s.triggerID) + fmt.Printf("resolver:subscription_updater:complete:%d\n", s.triggerID) } + s.resolver.handleTriggerComplete(s.triggerID) +} - select { - case <-s.ctx.Done(): - // Skip sending events if trigger is already done +func (s *subscriptionUpdater) Error(data []byte) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done || s.ctx.Err() != nil { if s.debug { - fmt.Printf("resolver:subscription_updater:close:skip:%d\n", s.triggerID) + fmt.Printf("resolver:subscription_updater:error:skip:%d\n", s.triggerID) } return - case s.ch <- subscriptionEvent{ - triggerID: s.triggerID, - kind: subscriptionEventKindTriggerClose, - closeKind: kind, - }: - if s.debug { - fmt.Printf("resolver:subscription_updater:close:sent_event:%d\n", s.triggerID) - } } + if s.debug { + fmt.Printf("resolver:subscription_updater:error:%d\n", s.triggerID) + } + s.resolver.handleTriggerError(s.triggerID, data) } -func (s *subscriptionUpdater) CloseSubscription(kind SubscriptionCloseKind, id SubscriptionIdentifier) { +func (s *subscriptionUpdater) Done() { + s.mu.Lock() + defer s.mu.Unlock() + if s.done { + return + } + s.done = true if s.debug { - fmt.Printf("resolver:subscription_updater:close:%d\n", s.triggerID) + fmt.Printf("resolver:subscription_updater:done:%d\n", s.triggerID) } + s.resolver.doneTriggerFromUpdater(s.triggerID) +} - select { - case <-s.ctx.Done(): - // Skip sending events if trigger is already done +func (s *subscriptionUpdater) CloseSubscription(id SubscriptionIdentifier) { + s.mu.Lock() + defer s.mu.Unlock() + if s.done || s.ctx.Err() != nil { if s.debug { fmt.Printf("resolver:subscription_updater:close:skip:%d\n", s.triggerID) } return - case s.ch <- subscriptionEvent{ - triggerID: s.triggerID, - kind: subscriptionEventKindRemoveSubscription, - closeKind: kind, - id: id, - }: - if s.debug { - fmt.Printf("resolver:subscription_updater:close:sent_event:%d\n", s.triggerID) - } } -} + if s.debug { + fmt.Printf("resolver:subscription_updater:close:%d\n", s.triggerID) + } -type subscriptionEvent struct { - triggerID uint64 - id SubscriptionIdentifier - kind subscriptionEventKind - data []byte - addSubscription *addSubscription - closeKind SubscriptionCloseKind + _ = s.resolver.UnsubscribeSubscription(id) } type addSubscription struct { @@ -1580,32 +1579,22 @@ type addSubscription struct { headers http.Header } -type subscriptionEventKind int - -const ( - subscriptionEventKindUnknown subscriptionEventKind = iota - subscriptionEventKindTriggerUpdate - subscriptionEventKindTriggerComplete - subscriptionEventKindAddSubscription - subscriptionEventKindRemoveSubscription - subscriptionEventKindCompleteSubscription - subscriptionEventKindRemoveClient - subscriptionEventKindTriggerInitialized - subscriptionEventKindTriggerClose - subscriptionEventKindUpdateSubscription -) - type SubscriptionUpdater interface { // Update sends an update to the client. It is not guaranteed that the update is sent immediately. Update(data []byte) // UpdateSubscription sends an update to a single subscription. It is not guaranteed that the update is sent immediately. UpdateSubscription(id SubscriptionIdentifier, data []byte) - // Complete also takes care of cleaning up the trigger and all subscriptions. No more updates should be sent after calling Complete. + // Complete delivers a "subscription done" signal to all subscriptions on the trigger. + // Does not perform cleanup — call Done() after Complete(). Complete() - // Close closes the subscription and cleans up the trigger and all subscriptions. No more updates should be sent after calling Close. - Close(kind SubscriptionCloseKind) + // Error delivers a terminal error to all subscriptions on the trigger, bypassing the resolve pipeline. + // Does not perform cleanup — call Done() after Error(). + Error(data []byte) + // Done performs internal cleanup: detaches the trigger, closes completed channels. + // Must always be the final call. Does not send any downstream messages. + Done() // CloseSubscription closes a single subscription. No more updates should be sent to that subscription after calling CloseSubscription. - CloseSubscription(kind SubscriptionCloseKind, id SubscriptionIdentifier) + CloseSubscription(id SubscriptionIdentifier) // Subscriptions return all the subscriptions associated to this Updater Subscriptions() map[context.Context]SubscriptionIdentifier } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 98568556ab..8ce7f2066d 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -8,6 +8,7 @@ import ( "io" "net" "net/http" + "slices" "sync" "sync/atomic" "testing" @@ -1651,7 +1652,6 @@ func testFnWithPostEvaluationAndOptions(opts ResolverOptions, fn func(t *testing } func TestResolver_ResolveGraphQLResponse(t *testing.T) { - t.Run("empty graphql response", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ Data: &Object{ @@ -3066,8 +3066,9 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("when data null and errors present not nullable array should result to null data upstream error and resolve error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ Fetches: SingleWithPath(&SingleFetch{ - FetchConfiguration: FetchConfiguration{DataSource: FakeDataSource( - `{"errors":[{"message":"Could not get name","locations":[{"line":3,"column":5}],"path":["todos","0","name"]}],"data":null}`), + FetchConfiguration: FetchConfiguration{ + DataSource: FakeDataSource( + `{"errors":[{"message":"Could not get name","locations":[{"line":3,"column":5}],"path":["todos","0","name"]}],"data":null}`), PostProcessing: PostProcessingConfiguration{ SelectResponseDataPath: []string{"data"}, SelectResponseErrorsPath: []string{"errors"}, @@ -3917,7 +3918,6 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, { - Name: []byte("reviews"), Value: &Array{ Path: []string{"reviews"}, @@ -3964,7 +3964,6 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, *NewContext(context.Background()), `{"data":{"me":{"id":"1234","username":"Me","reviews":[{"body":"foo","product":{"upc":"top-1","name":"Trilby"}},{"body":"bar","product":{"upc":"top-2","name":"Fedora"}},{"body":"baz","product":null},{"body":"bat","product":{"upc":"top-4","name":"Boater"}},{"body":"bal","product":{"upc":"top-5","name":"Top Hat"}},{"body":"ban","product":{"upc":"top-6","name":"Bowler"}}]}}}` })) t.Run("federation with fetch error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - userService := NewMockDataSource(ctrl) userService.EXPECT(). Load(gomock.Any(), gomock.Any(), gomock.Any()). @@ -4105,7 +4104,6 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, { - Name: []byte("reviews"), Value: &Array{ Path: []string{"reviews"}, @@ -4151,7 +4149,6 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, *NewContext(context.Background()), `{"errors":[{"message":"Failed to fetch from Subgraph at Path 'query.me.reviews.@.product', Reason: no data or errors in response."},{"message":"Cannot return null for non-nullable field 'Query.me.reviews.product.name'.","path":["me","reviews",0,"product","name"]},{"message":"Cannot return null for non-nullable field 'Query.me.reviews.product.name'.","path":["me","reviews",1,"product","name"]}],"data":{"me":{"id":"1234","username":"Me","reviews":[null,null]}}}` })) t.Run("federation with fetch error and non null fields inside an array", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { - userService := NewMockDataSource(ctrl) userService.EXPECT(). Load(gomock.Any(), gomock.Any(), gomock.Any()). @@ -4291,7 +4288,6 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, { - Name: []byte("reviews"), Value: &Array{ Path: []string{"reviews"}, @@ -4567,7 +4563,6 @@ func testFnArena(fn func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLRe } func TestResolver_ArenaResolveGraphQLResponse(t *testing.T) { - t.Run("empty graphql response", testFnArena(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string) { resolveCtx := NewContext(context.Background()) return &GraphQLResponse{ @@ -5662,7 +5657,7 @@ type SubscriptionRecorder struct { buf *bytes.Buffer messages []string complete atomic.Bool - closed atomic.Bool + errors [][]byte mux sync.Mutex onFlush func(p []byte) } @@ -5717,20 +5712,6 @@ func (s *SubscriptionRecorder) AwaitComplete(t *testing.T, timeout time.Duration } } -func (s *SubscriptionRecorder) AwaitClosed(t *testing.T, timeout time.Duration) { - t.Helper() - deadline := time.Now().Add(timeout) - for { - if s.closed.Load() { - return - } - if time.Now().After(deadline) { - t.Fatalf("timed out waiting for close") - } - time.Sleep(time.Millisecond * 10) - } -} - func (s *SubscriptionRecorder) Write(p []byte) (n int, err error) { s.mux.Lock() defer s.mux.Unlock() @@ -5756,8 +5737,10 @@ func (s *SubscriptionRecorder) Heartbeat() error { return nil } -func (s *SubscriptionRecorder) Close(_ SubscriptionCloseKind) { - s.closed.Store(true) +func (s *SubscriptionRecorder) Error(data []byte) { + s.mux.Lock() + s.errors = append(s.errors, slices.Clone(data)) + s.mux.Unlock() } func (s *SubscriptionRecorder) Messages() []string { @@ -5817,7 +5800,7 @@ func (f *_fakeStream) Start(ctx *Context, headers http.Header, input []byte, upd for { select { case <-ctx.ctx.Done(): - updater.Complete() + updater.Done() f.isDone.Store(true) return default: @@ -5826,6 +5809,7 @@ func (f *_fakeStream) Start(ctx *Context, headers http.Header, input []byte, upd if done { time.Sleep(f.delay) updater.Complete() + updater.Done() f.isDone.Store(true) return } @@ -5844,7 +5828,6 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { } setup := func(ctx context.Context, stream SubscriptionDataSource) (*Resolver, *GraphQLSubscription, *SubscriptionRecorder, SubscriptionIdentifier) { - fetches := Sequence() fetches.Trigger = &FetchTreeNode{ Kind: FetchTreeNodeKindTrigger, @@ -6099,7 +6082,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { messages := recorder.Messages() assert.Greater(t, len(messages), 2) - time.Sleep(resolver.heartbeatInterval) + time.Sleep(10 * time.Millisecond) // Validate that despite the time, we don't see any heartbeats sent assert.Contains(t, messages, `{"data":{"counter":0}}`) assert.Contains(t, messages, `{"data":{"counter":1}}`) @@ -6126,7 +6109,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { } for i := 1; i <= 10; i++ { - id.ConnectionID = int64(i) + id.ConnectionID = ConnectionID(i) id.SubscriptionID = int64(i) recorder.complete.Store(false) err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id) @@ -6136,7 +6119,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { recorder.AwaitComplete(t, defaultTimeout) - time.Sleep(resolver.heartbeatInterval) + time.Sleep(10 * time.Millisecond) assert.Len(t, recorder.Messages(), 20) @@ -6239,9 +6222,8 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { err := resolver.AsyncResolveGraphQLSubscription(&ctx, plan, recorder, id) assert.NoError(t, err) recorder.AwaitAnyMessageCount(t, defaultTimeout) - err = resolver.AsyncUnsubscribeSubscription(id) + err = resolver.UnsubscribeSubscription(id) assert.NoError(t, err) - recorder.AwaitClosed(t, defaultTimeout) fakeStream.AwaitIsDone(t, defaultTimeout) }) @@ -6264,9 +6246,28 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { err := resolver.AsyncResolveGraphQLSubscription(&ctx, plan, recorder, id) assert.NoError(t, err) recorder.AwaitAnyMessageCount(t, defaultTimeout) - err = resolver.AsyncUnsubscribeClient(id.ConnectionID) + err = resolver.UnsubscribeClient(id.ConnectionID) + assert.NoError(t, err) + fakeStream.AwaitIsDone(t, defaultTimeout) + }) + + t.Run("should stop stream on unsubscribe client with close reason", func(t *testing.T) { + c, cancel := context.WithCancel(context.Background()) + defer cancel() + + fakeStream := createFakeStream(func(counter int) (message string, done bool) { + return fmt.Sprintf(`{"data":{"counter":%d}}`, counter), false + }, time.Millisecond*10, nil, nil) + + resolver, plan, recorder, id := setup(c, fakeStream) + + ctx := NewContext(context.Background()) + + err := resolver.AsyncResolveGraphQLSubscription(ctx, plan, recorder, id) + assert.NoError(t, err) + recorder.AwaitAnyMessageCount(t, defaultTimeout) + err = resolver.UnsubscribeClient(id.ConnectionID) assert.NoError(t, err) - recorder.AwaitClosed(t, defaultTimeout) fakeStream.AwaitIsDone(t, defaultTimeout) }) @@ -6340,13 +6341,6 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { resolver, plan, _, _ := setup(c, fakeStream) - ctx := &Context{ - ctx: context.Background(), - ExecutionOptions: ExecutionOptions{ - SendHeartbeat: true, - }, - } - const numSubscriptions = 2 var resolverCompleted atomic.Uint32 var recorderCompleted atomic.Uint32 @@ -6358,6 +6352,11 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { } recorder.complete.Store(false) + // Each subscription needs its own Context so they get separate entries + // in the trigger's subscriptions map (keyed by *Context). + subCtx := NewContext(context.Background()) + subCtx.ExecutionOptions.SendHeartbeat = true + go func() { defer recorderCompleted.Add(1) @@ -6367,7 +6366,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { go func() { defer resolverCompleted.Add(1) - err := resolver.ResolveGraphQLSubscription(ctx, plan, recorder) + err := resolver.ResolveGraphQLSubscription(subCtx, plan, recorder) assert.ErrorIs(t, err, context.Canceled) }() } @@ -6408,9 +6407,8 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { assert.NoError(t, err) recorder.AwaitAnyMessageCount(t, defaultTimeout) - err = resolver.AsyncUnsubscribeSubscription(id) + err = resolver.UnsubscribeSubscription(id) assert.NoError(t, err) - recorder.AwaitClosed(t, defaultTimeout) fakeStream.AwaitIsDone(t, defaultTimeout) }) @@ -8048,6 +8046,199 @@ func Benchmark_NestedBatchingArena(b *testing.B) { }) } +// startFailStream blocks until subBReady is closed, then returns an error from Start. +type startFailStream struct { + subBReady chan struct{} +} + +func (s *startFailStream) Start(_ *Context, _ http.Header, _ []byte, _ SubscriptionUpdater) error { + <-s.subBReady + return errors.New("connection refused") +} + +func TestSourceStartFailure(t *testing.T) { + t.Run("broadcasts error to all subscribers", func(t *testing.T) { + resolverCtx, cancelResolver := context.WithCancel(context.Background()) + defer cancelResolver() + + subBReady := make(chan struct{}) + + stream := &startFailStream{subBReady: subBReady} + + plan := &GraphQLSubscription{ + Trigger: GraphQLSubscriptionTrigger{ + Source: stream, + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + SegmentType: StaticSegmentType, + Data: []byte(`{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`), + }, + }, + }, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + Response: &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("counter"), + Value: &Integer{ + Path: []string{"counter"}, + }, + }, + }, + }, + }, + } + + resolver := New(resolverCtx, ResolverOptions{ + MaxConcurrency: 1024, + AsyncErrorWriter: &TestErrorWriter{}, + SubscriptionHeartbeatInterval: time.Hour, + }) + + writerA := &SubscriptionRecorder{buf: &bytes.Buffer{}} + writerB := &SubscriptionRecorder{buf: &bytes.Buffer{}} + + idA := SubscriptionIdentifier{ConnectionID: NewConnectionID(), SubscriptionID: 1} + idB := SubscriptionIdentifier{ConnectionID: NewConnectionID(), SubscriptionID: 2} + + ctxA := NewContext(context.Background()) + ctxB := NewContext(context.Background()) + + // Subscribe A — creates the trigger. + err := resolver.AsyncResolveGraphQLSubscription(ctxA, plan, writerA, idA) + require.NoError(t, err) + + // Subscribe B — joins the existing trigger. + err = resolver.AsyncResolveGraphQLSubscription(ctxB, plan, writerB, idB) + require.NoError(t, err) + + // Unblock Source.Start() so it returns an error. + close(subBReady) + + timeout := 2 * time.Second + writerA.AwaitAnyMessageCount(t, timeout) + writerB.AwaitAnyMessageCount(t, timeout) + }) +} + +// hookFailStream is a SubscriptionDataSource + HookableSubscriptionDataSource. +// The hook for the subscription whose context matches failCtx blocks until +// subBRegistered is closed, then returns an error. All other hooks succeed. +type hookFailStream struct { + subBRegistered chan struct{} + failCtx context.Context + sourceStarted atomic.Bool +} + +func (s *hookFailStream) Start(_ *Context, _ http.Header, _ []byte, _ SubscriptionUpdater) error { + s.sourceStarted.Store(true) + select {} +} + +func (s *hookFailStream) SubscriptionOnStart(ctx StartupHookContext, _ []byte) error { + if ctx.Context == s.failCtx { + <-s.subBRegistered + return errors.New("startup hook failed") + } + return nil +} + +func TestStartupHookFailure(t *testing.T) { + t.Run("cleans up all subscribers when trigger creator hook fails", func(t *testing.T) { + resolverCtx, cancelResolver := context.WithCancel(context.Background()) + defer cancelResolver() + + ctxAInner, cancelA := context.WithCancel(context.Background()) + defer cancelA() + ctxBInner, cancelB := context.WithCancel(context.Background()) + defer cancelB() + + subBRegistered := make(chan struct{}) + + stream := &hookFailStream{ + subBRegistered: subBRegistered, + failCtx: ctxAInner, + } + + plan := &GraphQLSubscription{ + Trigger: GraphQLSubscriptionTrigger{ + Source: stream, + InputTemplate: InputTemplate{ + Segments: []TemplateSegment{ + { + SegmentType: StaticSegmentType, + Data: []byte(`{"method":"POST","url":"http://localhost:4000","body":{"query":"subscription { counter }"}}`), + }, + }, + }, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + Response: &GraphQLResponse{ + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("counter"), + Value: &Integer{ + Path: []string{"counter"}, + }, + }, + }, + }, + }, + } + + resolver := New(resolverCtx, ResolverOptions{ + MaxConcurrency: 1024, + AsyncErrorWriter: &TestErrorWriter{}, + SubscriptionHeartbeatInterval: time.Hour, + }) + + writerA := &SubscriptionRecorder{buf: &bytes.Buffer{}} + writerB := &SubscriptionRecorder{buf: &bytes.Buffer{}} + + idA := SubscriptionIdentifier{ConnectionID: NewConnectionID(), SubscriptionID: 1} + idB := SubscriptionIdentifier{ConnectionID: NewConnectionID(), SubscriptionID: 2} + + ctxA := NewContext(ctxAInner) + ctxB := NewContext(ctxBInner) + + // Subscribe A — creates the trigger. + err := resolver.AsyncResolveGraphQLSubscription(ctxA, plan, writerA, idA) + require.NoError(t, err) + + // Subscribe B — joins the existing trigger. + err = resolver.AsyncResolveGraphQLSubscription(ctxB, plan, writerB, idB) + require.NoError(t, err) + + // Unblock A's startup hook so it fails. + close(subBRegistered) + + // Sub A should receive an error. + writerA.AwaitAnyMessageCount(t, time.Second) + + // Sub B should also be cleaned up — not left orphaned on a triggerless source. + require.Eventually(t, func() bool { + writerB.mux.Lock() + hasMessages := len(writerB.messages) > 0 + writerB.mux.Unlock() + return hasMessages || writerB.complete.Load() + }, 2*time.Second, 10*time.Millisecond, + "sub B was orphaned: no error, no complete — "+ + "trigger left without a data source after first subscription's startup hook failed") + + require.False(t, stream.sourceStarted.Load(), "Source.Start() should not have been called") + }) +} + func Benchmark_NoCheckNestedBatching(b *testing.B) { rCtx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/v2/pkg/engine/resolve/event_loop_test.go b/v2/pkg/engine/resolve/resolver_subscription_test.go similarity index 69% rename from v2/pkg/engine/resolve/event_loop_test.go rename to v2/pkg/engine/resolve/resolver_subscription_test.go index ba8b7c8e2f..0421316454 100644 --- a/v2/pkg/engine/resolve/event_loop_test.go +++ b/v2/pkg/engine/resolve/resolver_subscription_test.go @@ -2,6 +2,7 @@ package resolve import ( "context" + "errors" "io" "net/http" "sync" @@ -15,7 +16,6 @@ import ( type FakeErrorWriter struct{} func (f *FakeErrorWriter) WriteError(ctx *Context, err error, res *GraphQLResponse, w io.Writer) { - } type FakeSubscriptionWriter struct { @@ -23,7 +23,6 @@ type FakeSubscriptionWriter struct { buf []byte writtenMessages []string completed bool - closed bool messageCountOnComplete int } @@ -59,11 +58,7 @@ func (f *FakeSubscriptionWriter) Heartbeat() error { return nil } -func (f *FakeSubscriptionWriter) Close(SubscriptionCloseKind) { - f.mu.Lock() - defer f.mu.Unlock() - f.closed = true - f.messageCountOnComplete = len(f.writtenMessages) +func (f *FakeSubscriptionWriter) Error([]byte) { } type FakeSource struct { @@ -80,17 +75,35 @@ func (f *FakeSource) Start(ctx *Context, headers http.Header, input []byte, upda } } updater.Complete() + updater.Done() }() return nil } +type FailingHeartbeatWriter struct{} + +func (f *FailingHeartbeatWriter) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (f *FailingHeartbeatWriter) Flush() error { + return nil +} + +func (f *FailingHeartbeatWriter) Complete() {} + +func (f *FailingHeartbeatWriter) Heartbeat() error { + return errors.New("heartbeat failed") +} + +func (f *FailingHeartbeatWriter) Error([]byte) {} + type TestReporter struct { triggers atomic.Int64 subscriptions atomic.Int64 } func (t *TestReporter) SubscriptionUpdateSent() { - } func (t *TestReporter) SubscriptionCountInc(count int) { @@ -110,7 +123,6 @@ func (t *TestReporter) TriggerCountDec(count int) { } func TestEventLoop(t *testing.T) { - resolverCtx, stopEventLoop := context.WithCancel(context.Background()) t.Cleanup(stopEventLoop) @@ -187,5 +199,58 @@ func TestEventLoop(t *testing.T) { require.Equal(t, int64(0), subscriptionCount) return true }, time.Second, time.Millisecond*10) +} + +func TestResolver_HeartbeatError_DoesNotDeadlockOnUnsubscribe(t *testing.T) { + resolverCtx, cancelResolver := context.WithCancel(context.Background()) + defer cancelResolver() + + resolver := New(resolverCtx, ResolverOptions{ + MaxConcurrency: 1, + AsyncErrorWriter: &FakeErrorWriter{}, + SubscriptionHeartbeatInterval: time.Hour, // Long interval to prevent background heartbeat loop from competing + }) + + subCtx := NewContext(context.Background()) + subID := SubscriptionIdentifier{ + ConnectionID: 1, + SubscriptionID: 1, + } + triggerID := uint64(42) + s := &subscriptionState{ + triggerID: triggerID, + ctx: subCtx, + writer: &FailingHeartbeatWriter{}, + id: subID, + heartbeat: true, + completed: make(chan struct{}), + } + resolver.mu.Lock() + resolver.triggers[triggerID] = &trigger{ + id: triggerID, + cancel: func() {}, + subscriptions: map[SubscriptionIdentifier]*subscriptionState{subID: s}, + } + resolver.subscriptionsByID[subID] = s + resolver.subscriptionsByConnection[subID.ConnectionID] = map[SubscriptionIdentifier]*subscriptionState{subID: s} + resolver.mu.Unlock() + + done := make(chan struct{}) + go func() { + resolver.heartbeatTriggerSubscriptions(triggerID) + close(done) + }() + + select { + case <-done: + case <-time.After(250 * time.Millisecond): + t.Fatal("heartbeatTriggerSubscriptions deadlocked after heartbeat error") + } + + select { + case <-s.completed: + case <-time.After(250 * time.Millisecond): + t.Fatal("subscription was not closed after heartbeat failure") + } } diff --git a/v2/pkg/engine/resolve/response.go b/v2/pkg/engine/resolve/response.go index d8af8d017b..36d2f3a7e1 100644 --- a/v2/pkg/engine/resolve/response.go +++ b/v2/pkg/engine/resolve/response.go @@ -3,8 +3,6 @@ package resolve import ( "io" - "github.com/gobwas/ws" - "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" ) @@ -68,23 +66,12 @@ type ResponseWriter interface { io.Writer } -type SubscriptionCloseKind struct { - WSCode ws.StatusCode - Reason string -} - -var ( - SubscriptionCloseKindNormal SubscriptionCloseKind = SubscriptionCloseKind{ws.StatusNormalClosure, "Normal closure"} - SubscriptionCloseKindDownstreamServiceError SubscriptionCloseKind = SubscriptionCloseKind{ws.StatusGoingAway, "Downstream service error"} - SubscriptionCloseKindGoingAway SubscriptionCloseKind = SubscriptionCloseKind{ws.StatusGoingAway, "Going away"} -) - type SubscriptionResponseWriter interface { ResponseWriter Flush() error Complete() Heartbeat() error - Close(kind SubscriptionCloseKind) + Error(data []byte) } func writeGraphqlResponse(buf *BufPair, writer io.Writer, ignoreData bool) (err error) { diff --git a/v2/pkg/engine/resolve/subscription_filter.go b/v2/pkg/engine/resolve/subscription_filter.go index 25e0df5d65..2bffe480a0 100644 --- a/v2/pkg/engine/resolve/subscription_filter.go +++ b/v2/pkg/engine/resolve/subscription_filter.go @@ -11,6 +11,7 @@ import ( "github.com/wundergraph/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" + "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) type SubscriptionFilter struct { @@ -25,14 +26,14 @@ type SubscriptionFieldFilter struct { Values []InputTemplate } -func (f *SubscriptionFilter) SkipEvent(ctx *Context, data []byte, buf *bytes.Buffer) (bool, error) { +func (f *SubscriptionFilter) SkipEvent(ctx *Context, data []byte) (bool, error) { if f == nil { return false, nil } if f.And != nil { for _, filter := range f.And { - skip, err := filter.SkipEvent(ctx, data, buf) + skip, err := filter.SkipEvent(ctx, data) if err != nil { return false, err } @@ -47,7 +48,7 @@ func (f *SubscriptionFilter) SkipEvent(ctx *Context, data []byte, buf *bytes.Buf if f.Or != nil { for _, filter := range f.Or { - skip, err := filter.SkipEvent(ctx, data, buf) + skip, err := filter.SkipEvent(ctx, data) if err != nil { return false, err } @@ -61,7 +62,7 @@ func (f *SubscriptionFilter) SkipEvent(ctx *Context, data []byte, buf *bytes.Buf } if f.Not != nil { - skip, err := f.Not.SkipEvent(ctx, data, buf) + skip, err := f.Not.SkipEvent(ctx, data) if err != nil { return false, err } @@ -69,7 +70,7 @@ func (f *SubscriptionFilter) SkipEvent(ctx *Context, data []byte, buf *bytes.Buf } if f.In != nil { - return f.In.SkipEvent(ctx, data, buf) + return f.In.SkipEvent(ctx, data) } return false, nil @@ -84,7 +85,7 @@ var ( ErrInvalidSubscriptionFilterTemplate = errors.New("invalid subscription filter template") ) -func (f *SubscriptionFieldFilter) SkipEvent(ctx *Context, data []byte, buf *bytes.Buffer) (bool, error) { +func (f *SubscriptionFieldFilter) SkipEvent(ctx *Context, data []byte) (bool, error) { if f == nil { return false, nil } @@ -94,6 +95,11 @@ func (f *SubscriptionFieldFilter) SkipEvent(ctx *Context, data []byte, buf *byte return true, nil } + // Scratch buffer for rendering filter template values. Pooled to avoid + // per-event allocations. + buf := pool.BytesBuffer.Get() + defer pool.BytesBuffer.Put(buf) + for i := range f.Values { buf.Reset() err := f.Values[i].Render(ctx, nil, buf) diff --git a/v2/pkg/engine/resolve/subscription_filter_test.go b/v2/pkg/engine/resolve/subscription_filter_test.go index 807ef706e6..98c6eb8068 100644 --- a/v2/pkg/engine/resolve/subscription_filter_test.go +++ b/v2/pkg/engine/resolve/subscription_filter_test.go @@ -1,7 +1,6 @@ package resolve import ( - "bytes" "testing" "github.com/stretchr/testify/assert" @@ -31,9 +30,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":true}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":true}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -58,9 +56,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"false"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":true}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -85,18 +82,16 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"true"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":true}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) c = &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":true}`)), } - buf = &bytes.Buffer{} data = []byte(`{"event":"true"}`) - skip, err = filter.SkipEvent(c, data, buf) + skip, err = filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -121,9 +116,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":1.13}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":1.13}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -148,18 +142,16 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"1.13"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":1.13}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) c = &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":1.13}`)), } - buf = &bytes.Buffer{} data = []byte(`{"event":"1.13"}`) - skip, err = filter.SkipEvent(c, data, buf) + skip, err = filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -184,9 +176,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":49}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":49}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -211,18 +202,16 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"49"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":49}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) c = &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":49}`)), } - buf = &bytes.Buffer{} data = []byte(`{"event":"49"}`) - skip, err = filter.SkipEvent(c, data, buf) + skip, err = filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -247,9 +236,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"9.77"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":8.01}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -274,9 +262,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":123}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":321}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -301,9 +288,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":true}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":true}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -328,9 +314,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":["a","b"]}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -355,9 +340,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":[1,"2"]}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":2}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -382,9 +366,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":["a","b","c"]}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -411,9 +394,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"b"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":"b"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -440,9 +422,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"var":"b"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"event":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -488,9 +469,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"first":"b","second":"c"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -530,9 +510,8 @@ func TestSubscriptionFilter(t *testing.T) { }, } c := &Context{} - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -557,9 +536,8 @@ func TestSubscriptionFilter(t *testing.T) { }, } c := &Context{} - buf := &bytes.Buffer{} data := []byte(`{"eventX":true,"eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -586,9 +564,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"id":1}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":1,"eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -613,9 +590,8 @@ func TestSubscriptionFilter(t *testing.T) { }, } c := &Context{} - buf := &bytes.Buffer{} data := []byte(`{"eventX":null,"eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -661,9 +637,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"first":"d","second":"c"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -709,9 +684,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"first":"b","unused":"c"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, true, skip) }) @@ -757,9 +731,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"first":"b","second":"c"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -805,9 +778,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"first":"b","unused":"c"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -853,9 +825,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"third":"b","second":"c","fourth":1}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c","fourth":1}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -901,9 +872,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"third":"b","second":"c"}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c"}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) }) @@ -956,9 +926,8 @@ func TestSubscriptionFilter(t *testing.T) { c := &Context{ Variables: astjson.MustParseBytes([]byte(`{"third":"b","second":"c","fourth":1}`)), } - buf := &bytes.Buffer{} data := []byte(`{"eventX":"b","eventY":"c1","fourth":1}`) - skip, err := filter.SkipEvent(c, data, buf) + skip, err := filter.SkipEvent(c, data) assert.NoError(t, err) assert.Equal(t, false, skip) })